
Fine-tuning the LLaMA 2 model on RTX 4090

- Team Vast

July 24, 2023 | GPU

Introduction

Meta has recently launched Llama 2, the latest edition of the widely recognized Llama models, trained on 40% more data than the original LLaMA. The release is exciting given the meteoric rise in popularity of open-source large language models (LLMs) like Llama. Llama 2 bears some similarities to other large language models, but its license permits free commercial use for most users, which levels the playing field. Llama 2 was pre-trained on an enormous dataset of publicly available online text and code. The fine-tuned model, Llama-2-chat, builds on this pre-trained model and was further trained with over 1 million human annotations. The original LLaMA models have already given rise to popular fine-tuned offspring such as Vicuna.

Concurrently, the open-source community has released an abundance of tools for fine-tuning and deploying these language models. Libraries such as PEFT (parameter-efficient fine-tuning), bitsandbytes (8-bit and 4-bit quantization), and TRL (Transformer Reinforcement Learning, which includes a supervised fine-tuning trainer) make it possible to fine-tune LLMs on GPUs that don't have enough memory to hold the full-precision model.
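To see why quantization matters on a 24 GB consumer card, here is a rough back-of-the-envelope sketch of how much memory the weights of a 7B model occupy at different precisions. It assumes about 6.74 billion parameters for Llama 2 7B and counts weights only; activations, gradients, optimizer state and CUDA overhead come on top of this in practice.

```python
# Rough estimate of the GPU memory needed just to hold a model's weights.
# Assumption: ~6.74e9 parameters for Llama 2 7B. Activations, gradients,
# optimizer state and framework overhead are NOT included.

BYTES_PER_PARAM = {
    "fp32": 4.0,       # full precision
    "fp16/bf16": 2.0,  # half precision
    "int8": 1.0,       # 8-bit quantization (e.g. bitsandbytes)
    "4-bit": 0.5,      # 4-bit quantization
}


def weight_memory_gib(num_params: float, bytes_per_param: float) -> float:
    """Size of the weights in GiB (1 GiB = 2**30 bytes)."""
    return num_params * bytes_per_param / 2**30


NUM_PARAMS = 6.74e9   # assumed parameter count for Llama 2 7B
GPU_MEMORY_GIB = 24   # RTX 3090 / RTX 4090

for precision, nbytes in BYTES_PER_PARAM.items():
    size = weight_memory_gib(NUM_PARAMS, nbytes)
    verdict = "fit" if size < GPU_MEMORY_GIB else "do not fit"
    print(f"{precision:>9}: {size:5.1f} GiB -> weights {verdict} in {GPU_MEMORY_GIB} GiB")
```

Under these assumptions the full-precision weights alone (about 25 GiB) already exceed a 24 GB card, while 8-bit (about 6.3 GiB) and 4-bit (about 3.1 GiB) leave room for activations and the small set of trainable adapter weights that PEFT adds.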

The aim of this blog post is to guide you through fine-tuning Llama 2 models on the Vast platform. We benchmarked the process on an RTX 3090, an RTX 4090, and an A100 SXM4 80GB. See Vast for up-to-the-minute on-demand rental prices.

Benchmarks

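The RTX 4090 delivers an impressive 1.5 samples/sec when training in 8-bit with a batch size of 8. That is almost twice the A100's 8-bit throughput of 0.8 samples/sec; because throughput is measured in samples per second, the comparison already accounts for the A100 running a larger batch. At $0.50/hr for the RTX 4090 versus $1.50/hr for an A100 (prices at the time of writing), the RTX 4090 offers roughly 5.6X better price for performance.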
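The price-for-performance comparison is simple arithmetic: convert throughput to samples per hour and divide by the hourly price. This sketch plugs in the 8-bit throughput figures and hourly prices quoted in this post; rental prices on Vast change over time, so substitute current numbers.

```python
# Compare training cost-efficiency of two GPUs from throughput and price.
# Inputs are the 8-bit figures and hourly prices quoted in this post;
# plug in current prices, since rental rates change.


def samples_per_dollar(samples_per_sec: float, price_per_hour: float) -> float:
    """Number of training samples processed per dollar of rental time."""
    return samples_per_sec * 3600 / price_per_hour


rtx_4090 = samples_per_dollar(samples_per_sec=1.5, price_per_hour=0.50)
a100 = samples_per_dollar(samples_per_sec=0.8, price_per_hour=1.50)

print(f"RTX 4090: {rtx_4090:,.0f} samples per dollar")
print(f"A100:     {a100:,.0f} samples per dollar")
print(f"Ratio:    {rtx_4090 / a100:.1f}x")
```

With these inputs the RTX 4090 processes 10,800 samples per dollar versus 1,920 for the A100, a ratio of about 5.6x.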

| Model | GPU | Batch size | 4-bit samples/sec[1] | 8-bit samples/sec[1] |
|---|---|---|---|---|
| 7B | RTX 4090 | 8 | 1 | 1.5 |
| 7B | A100 SXM4 80GB | 32 | 0.5 | 0.8 |
| 7B | RTX 3090 | 8 | 0.5 | 0.9 |

[1] Here samples/sec is calculated as the batch size divided by the seconds per iteration (s/iter) that the sft_trainer script reports. All training runs used 1 gradient accumulation step.
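The footnote's formula can be written as a pair of small helpers. This sketch also includes the gradient-accumulation factor, which is 1 for every run in the table; the seconds-per-iteration value in the example is derived from the table's RTX 4090 row, not read from an actual training log.

```python
# Convert the seconds-per-iteration (s/iter) figure reported during training
# into throughput in samples per second, and back again.


def samples_per_sec(batch_size: int, sec_per_iter: float,
                    grad_accum_steps: int = 1) -> float:
    """Samples/sec = (batch size x accumulation steps) / seconds per iteration."""
    if sec_per_iter <= 0:
        raise ValueError("sec_per_iter must be positive")
    return batch_size * grad_accum_steps / sec_per_iter


def implied_sec_per_iter(batch_size: int, samples_per_second: float,
                         grad_accum_steps: int = 1) -> float:
    """Inverse: the s/iter a run must show to reach a given throughput."""
    return batch_size * grad_accum_steps / samples_per_second


# The RTX 4090 8-bit row (batch size 8, 1.5 samples/sec) implies ~5.33 s/iter.
implied = implied_sec_per_iter(batch_size=8, samples_per_second=1.5)
print(f"{implied:.2f} s/iter -> {samples_per_sec(8, implied):.2f} samples/sec")
```

The accumulation factor is there because, with gradient accumulation, each optimizer step processes batch size times accumulation steps samples.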

How to run this yourself on Vast

Follow the rest of this post for a guide on how to fine-tune LLaMA 2 on Vast using TRL.

Sign up with Meta and Hugging Face for access:

1- Request access from Meta here: https://ai.meta.com/resources/models-and-libraries/llama-downloads/

2- Request access from Hugging Face on any of the model pages (your Hugging Face account email must match the email you gave Meta): https://huggingface.co/meta-llama/Llama-2-7b

3- Set up an auth token with Hugging Face here: https://huggingface.co/settings/tokens

Steps 1 and 2 are both required to get access to the Llama 2 weights; the token from step 3 lets you download them from the command line.

Rent a powerful GPU on Vast.ai

Vast has RTX 3090s, RTX 4090s and A100s for on-demand rentals. Our pricing is typically the best you can find online.

To run Llama 2 fine-tuning, you will want to use a PyTorch image on the machine you pick. To do that, click on this Vast console link, which selects our recommended PyTorch template with SSH and other recommended settings enabled. If you don't have an account yet, no problem! You will first need to set up a Vast account by registering and verifying your email address, and then purchasing credits.

To rent a machine, pick an RTX 3090 or RTX 4090 by selecting that filter in the upper right. Move the storage slider on the left of the interface to ~30GB so your instance has enough disk space to download the model weights.

Hit the rent button to start the instance. Once it has finished loading, click the blue >_ button to get the SSH details, then copy that SSH command and paste it into your terminal. You will now be logged into your Vast instance via SSH. If you have trouble setting up SSH, read our SSH docs.

Set up environment

Install the packages you will need:

pip install transformers peft trl bitsandbytes scipy
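Because the TRL example scripts and their flags change between library releases, it can help to record which versions you actually installed. A minimal sketch using only the Python standard library (importlib.metadata, Python 3.8+); the package list mirrors the pip command above plus torch, which the PyTorch template provides.

```python
# Print installed versions of the packages used in this guide.
# Uses only the standard library (importlib.metadata, Python 3.8+).
from importlib.metadata import PackageNotFoundError, version

PACKAGES = ["torch", "transformers", "peft", "trl", "bitsandbytes", "scipy"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:>14}: {ver or 'not installed'}")
```

Run it with python inside the instance after the pip install; a "not installed" line points to a package that still needs installing.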

Clone the TRL repository, which contains the training script:

git clone https://github.com/huggingface/trl

Log in to Hugging Face from the command line:

huggingface-cli login

Copy the auth token you created earlier (from https://huggingface.co/settings/tokens) and paste it into the prompt when asked. You can decline adding the token to your git credentials.

Fine-tune!

python trl/examples/scripts/sft_trainer.py \
    --model_name meta-llama/Llama-2-7b-hf \
    --dataset_name timdettmers/openassistant-guanaco \
    --load_in_8bit \
    --use_peft \
    --batch_size 8 \
    --gradient_accumulation_steps 1

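The first run downloads the model weights and the dataset automatically, so it will take a while before training actually starts. The --dataset_name flag selects the timdettmers/openassistant-guanaco dataset. To load the model in 4-bit instead, replace --load_in_8bit with --load_in_4bit; to change the batch size, change the value passed to --batch_size. The TRL example scripts and their flags change between releases, so if a flag is rejected, run the script with --help to see the arguments your copy accepts.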
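If you want to script several runs, for example to compare 4-bit and 8-bit loading or different batch sizes, a small Python wrapper can assemble the same command. build_sft_command is a hypothetical helper, not part of TRL; it only uses the flags from the command above plus --load_in_4bit, and assumes the script path inside the cloned repository.

```python
# Hypothetical helper: assemble the sft_trainer.py command used in this post
# so that quantization mode and batch size can be varied programmatically.
import shlex
from typing import List

SCRIPT = "trl/examples/scripts/sft_trainer.py"  # path inside the cloned TRL repo


def build_sft_command(quantization: str = "8bit", batch_size: int = 8,
                      grad_accum_steps: int = 1,
                      model: str = "meta-llama/Llama-2-7b-hf",
                      dataset: str = "timdettmers/openassistant-guanaco") -> List[str]:
    """Return the argument list for one fine-tuning run."""
    flags = {"8bit": "--load_in_8bit", "4bit": "--load_in_4bit"}
    if quantization not in flags:
        raise ValueError(f"quantization must be one of {sorted(flags)}")
    return [
        "python", SCRIPT,
        "--model_name", model,
        "--dataset_name", dataset,
        flags[quantization],
        "--use_peft",
        "--batch_size", str(batch_size),
        "--gradient_accumulation_steps", str(grad_accum_steps),
    ]


if __name__ == "__main__":
    cmd = build_sft_command(quantization="4bit", batch_size=8)
    print(shlex.join(cmd))  # inspect the command before running it
    # To launch: import subprocess; subprocess.run(cmd, check=True)
```

Returning a list rather than a single string avoids shell-quoting problems if you later pass it to subprocess.run.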

You should end up seeing output like this:

{'loss': 1.6493, 'learning_rate': 1.4096181965881397e-05, 'epoch': 0.0}
{'loss': 1.3571, 'learning_rate': 1.4092363931762796e-05, 'epoch': 0.0}
{'loss': 1.5853, 'learning_rate': 1.4088545897644193e-05, 'epoch': 0.0}
{'loss': 1.4237, 'learning_rate': 1.408472786352559e-05, 'epoch': 0.0}
{'loss': 1.7098, 'learning_rate': 1.4080909829406987e-05, 'epoch': 0.0}
{'loss': 1.4348, 'learning_rate': 1.4077091795288384e-05, 'epoch': 0.0}
{'loss': 1.6022, 'learning_rate': 1.407327376116978e-05, 'epoch': 0.01}
{'loss': 1.3352, 'learning_rate': 1.4069455727051177e-05, 'epoch': 0.01}
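Each logging step above is printed as a Python dictionary literal. If you save the training output to a file (for example by appending > train.log to the command), a short standard-library sketch like this one can pull out the loss values to track progress. It assumes each logged dictionary sits on its own line as shown above; the moving-average window is an arbitrary choice.

```python
# Extract loss values from trainer log lines like the ones shown above.
# Assumes each logging step appears on its own line as a Python dict literal.
import ast
from typing import Iterable, List


def parse_losses(lines: Iterable[str]) -> List[float]:
    """Return the 'loss' value from every line that is a dict literal."""
    losses = []
    for line in lines:
        line = line.strip()
        if not line.startswith("{"):
            continue  # skip progress bars and other output
        try:
            record = ast.literal_eval(line)
        except (ValueError, SyntaxError):
            continue  # not a complete dict literal
        if isinstance(record, dict) and "loss" in record:
            losses.append(float(record["loss"]))
    return losses


def moving_average(values: List[float], window: int = 4) -> List[float]:
    """Trailing moving average to smooth noisy per-step losses."""
    return [sum(values[i - window + 1:i + 1]) / window
            for i in range(window - 1, len(values))]


sample = """{'loss': 1.6493, 'learning_rate': 1.4096181965881397e-05, 'epoch': 0.0}
{'loss': 1.3571, 'learning_rate': 1.4092363931762796e-05, 'epoch': 0.0}
{'loss': 1.5853, 'learning_rate': 1.4088545897644193e-05, 'epoch': 0.0}
{'loss': 1.4237, 'learning_rate': 1.408472786352559e-05, 'epoch': 0.0}
{'loss': 1.7098, 'learning_rate': 1.4080909829406987e-05, 'epoch': 0.0}"""

losses = parse_losses(sample.splitlines())
print(losses)
print(moving_average(losses))
```

For a real log file, pass the open file object directly: with open("train.log") as f: losses = parse_losses(f).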

Summary

The price for performance of an RTX 4090 on Vast is impressive compared to an A100. In this post we have shown how easy it is to spin up a very low-cost GPU (as little as $0.20 per hour at the time of writing) and fine-tune Llama 2 models.
