
Generative AI: Faster fine-tuning of a Mistral Model with less memory

 

Another blog post starts with you beautiful people💥. I hope you have started your Generative AI learning with my last post; if you have not, I recommend reading that one before proceeding with this blog. In this post, we will explore fine-tuning the cutting-edge Mistral 7B model🚀. Generative AI has not only redefined the boundaries of artificial intelligence but also changed how we approach data generation and creativity. Join me on a journey through the intersection of machine learning and creativity, where Mistral 7B stands as a beacon of innovation in generative AI👈.

Fine-tuning a large language model (LLM) means further training the model on a specific, smaller dataset to adapt it to a particular task or domain. Large language models, such as the GPT models behind ChatGPT, are pre-trained on massive datasets to learn general language patterns and context. Fine-tuning lets users tailor the model for more specific applications, such as sentiment analysis, question answering, or domain-specific content generation👌.

But fine-tuning an LLM on your local system has its own challenges👿, for the following reasons-

1. Memory Requirements: Large language models have billions of parameters, which means high memory requirements. Fine-tuning has to keep the model weights in (GPU) memory, together with gradients, optimizer states and the activations of each batch. Limited memory therefore caps the size of the model and of the batches that can be processed efficiently.

2. Processing Power: Fine-tuning is a computationally intensive task. A machine without a capable GPU may lack the processing power to run the optimization efficiently. This leads to slow training and can make it impractical to train long enough for the model to converge.

3. Batch Size Constraints: During fine-tuning, models are typically trained in batches to improve computational efficiency. Limited memory restricts the batch size. This makes gradient estimates noisier and can slow down the convergence of the training process.

4. Data Loading Constraints: Loading and processing large datasets during fine-tuning can strain the available memory. Insufficient RAM can lead to out-of-memory errors while loading data, or to inefficient use of the available resources.
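To make the memory point concrete, here is a rough back-of-the-envelope estimate. It assumes about 7.24 billion parameters for Mistral 7B. It also uses the common rule of thumb of roughly 16 bytes per parameter for mixed-precision full fine-tuning with Adam. These are estimates, not measurements, and the function names are just illustrative.

```python
# Back-of-the-envelope GPU memory estimates for a ~7B-parameter model.
# Rules of thumb only: activations, quantization constants, the CUDA context
# and framework overhead are all ignored.

GIB = 1024 ** 3


def weights_gib(n_params: float, bits_per_param: float) -> float:
    """Memory needed just to hold the weights at a given precision."""
    return n_params * bits_per_param / 8 / GIB


def full_finetune_adam_gib(n_params: float) -> float:
    """Mixed-precision full fine-tuning with Adam, ~16 bytes per parameter:
    fp16 weights (2) + fp16 gradients (2) + fp32 master weights (4)
    + two fp32 Adam moment buffers (4 + 4)."""
    return n_params * 16 / GIB


N_PARAMS = 7.24e9  # assumed parameter count of Mistral 7B

for label, bits in [("fp32", 32), ("fp16", 16), ("4-bit", 4)]:
    print(f"weights in {label:>5}: {weights_gib(N_PARAMS, bits):6.1f} GiB")
print(f"full fine-tuning with Adam: {full_finetune_adam_gib(N_PARAMS):6.1f} GiB")
```

Consider a 16 GB T4. The fp16 weights alone take about 13.5 GiB and leave almost no room for training. The 4-bit weights take about 3.4 GiB and leave space for adapters, activations and optimizer state.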

To address these challenges, we are going to use a cutting-edge fine-tuning approach, QLoRA. We will run it through a library that, according to its authors, makes it about 2.2x faster while using 62% less memory💁. Can you believe that? Yes, really, so let's start in a Colab notebook using only a T4 GPU. The library that makes this faster, lower-memory fine-tuning possible is called unsloth👮. It can also fine-tune other LLMs like Llama 2 7B, so don't forget to give the library a star on GitHub👍. For our learning, we will stick to my favorite, Mistral 7B. The library provides a pre-quantized 4-bit Mistral 7B model, so we neither need to download the model from elsewhere nor convert the original model to 4-bit ourselves👌. Another benefit of this library is that it works with Hugging Face's TRL, Trainer and Seq2SeqTrainer, and even plain PyTorch code!
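For intuition about the "Q" in QLoRA: the frozen base weights are stored in 4 bits, and small LoRA adapters are trained on top. The real method uses the NF4 data type with double quantization, via bitsandbytes. The sketch below swaps in a simpler symmetric absmax scheme purely to show the core idea of blockwise 4-bit storage. Each block of 64 weights keeps one float scale plus small integers. The function names are illustrative, not part of any library.

```python
import random


def quantize_block(values, levels=7):
    """Symmetric absmax quantization of one block to integers in [-levels, levels]."""
    scale = max(abs(v) for v in values) or 1.0  # all-zero block: any scale works
    q = [round(v / scale * levels) for v in values]
    return q, scale


def dequantize_block(q, scale, levels=7):
    """Map stored integers back to approximate float values."""
    return [x / levels * scale for x in q]


def quantize(weights, block_size=64):
    """Split a flat weight list into blocks, each with its own scale."""
    return [quantize_block(weights[i:i + block_size])
            for i in range(0, len(weights), block_size)]


def dequantize(blocks):
    out = []
    for q, scale in blocks:
        out.extend(dequantize_block(q, scale))
    return out


random.seed(0)
weights = [random.gauss(0.0, 0.02) for _ in range(256)]  # toy "layer"
blocks = quantize(weights)
restored = dequantize(blocks)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(f"{len(blocks)} blocks, max abs error = {max_err:.5f}")
```

Per-block scales keep one outlier from ruining the precision of the whole tensor. That same motivation drives the block-wise design in QLoRA.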

First, we install the unsloth library. Its installation is sensitive to the GPU type, so we pick the build that matches our GPU. Then we install the Transformers library from Hugging Face, which provides APIs and tools to easily download and train state-of-the-art pre-trained models💪 -
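The "GPU type" distinction comes down to the GPU architecture. unsloth's Colab notebooks from around that time branched on the GPU's CUDA compute capability. Ampere and newer (8.0+) got one install variant, and older cards such as the T4 (7.5) got another. The exact pip extras are not reproduced here, so check the unsloth README for the current install command. The helper below only sketches the check, and its name is hypothetical.

```python
def is_ampere_or_newer(capability):
    """Ampere GPUs (compute capability 8.0) and newer support bfloat16 natively."""
    major, minor = capability
    return (major, minor) >= (8, 0)


# Compute capabilities of a few common GPUs.
EXAMPLES = {"Tesla T4": (7, 5), "V100": (7, 0), "A100": (8, 0), "RTX 3090": (8, 6)}

for name, cap in EXAMPLES.items():
    kind = "Ampere or newer" if is_ampere_or_newer(cap) else "pre-Ampere"
    print(f"{name:9} {cap[0]}.{cap[1]} -> {kind}")

# On a real machine (requires PyTorch with CUDA):
#   import torch
#   print(is_ampere_or_newer(torch.cuda.get_device_capability()))
```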


Next, we download the pre-trained 4-bit Mistral 7B model along with its tokenizer, as below-


See how easily, with very few lines of code, we can download the model💫. You will see the following output in the console while running the above code-

Before running the next piece of code, we need some background. Large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills. However, precise control of their behavior is difficult😑 because their training is completely unsupervised. Existing methods for gaining such steerability collect human labels of the relative quality of model generations. They then fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF).

However, RLHF is a complex and often unstable procedure😓. It first fits a reward model that reflects the human preferences. It then fine-tunes the large unsupervised LM with reinforcement learning to maximize this estimated reward without drifting too far from the original model. To address this big issue, a newer algorithm called DPO (Direct Preference Optimization)💪 was introduced. DPO fine-tunes LMs directly on preference data, with no separate reward model and no RL loop, and aligns them with human preferences as well as or better than existing methods. According to the DPO paper, fine-tuning with DPO exceeds PPO-based RLHF in controlling the sentiment of generations. It also matches or improves response quality in summarization and single-turn dialogue, while being substantially simpler to implement and train. Note that DPO needs preference pairs, that is, a chosen and a rejected response for each prompt. The training run later in this post uses supervised fine-tuning on instruction data instead.
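To see why DPO is simpler, here is its per-pair objective from the DPO paper written as a plain function. It assumes you already have the summed log-probabilities of the chosen and rejected responses under the model being trained (the policy) and under a frozen reference copy. beta=0.1 is just an example value. This illustrates the loss only; it is not the training code used in this post.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO loss for one preference pair.

    Each argument is a summed log-probability of a full response.
    The loss is -log sigmoid(beta * (chosen log-ratio - rejected log-ratio)),
    where each log-ratio compares the policy against the reference model.
    """
    chosen_ratio = policy_chosen - ref_chosen
    rejected_ratio = policy_rejected - ref_rejected
    return -log_sigmoid(beta * (chosen_ratio - rejected_ratio))


# Policy identical to the reference: loss is log(2).
print(dpo_loss(-10.0, -10.0, -10.0, -10.0))
# Policy prefers the chosen answer more than the reference does: loss drops.
print(dpo_loss(-8.0, -12.0, -10.0, -10.0))
```

There is no reward model and no sampling loop here. The loss behaves like a binary classifier that pushes the chosen response's probability up relative to the rejected one, anchored to the reference model.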

Don't worry, the next step is also very easy thanks to the unsloth library💓. We patch the model with LoRA adapters so that only a small fraction of all parameters, typically well under 10%, gets updated, as below-
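How small is "a small fraction"? Here is a rough count, assuming Mistral 7B's published dimensions and adapters on all seven attention and MLP projections. That is a common choice in unsloth's examples, but it is an assumption about this post's exact settings. A LoRA adapter on a layer with `in` inputs and `out` outputs adds r × (in + out) trainable parameters.

```python
# Assumed Mistral-7B dimensions (from its published config):
HIDDEN = 4096          # model width
INTERMEDIATE = 14336   # MLP width
KV_DIM = 1024          # 8 key/value heads x 128 head dim (grouped-query attention)
N_LAYERS = 32
TOTAL_PARAMS = 7.24e9  # approximate base-model parameter count

# (in_features, out_features) of each projection that gets a LoRA adapter
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_trainable_params(r: int) -> int:
    """Each adapter adds A (r x in) and B (out x r): r * (in + out) parameters."""
    per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in PROJECTIONS.values())
    return per_layer * N_LAYERS


for r in (8, 16, 64):
    n = lora_trainable_params(r)
    print(f"r={r:>2}: {n:,} trainable params ({100 * n / TOTAL_PARAMS:.2f}% of the base model)")
```

Under these assumptions, r=16 trains well under 1% of the weights. The share grows linearly with r and with the number of adapted modules.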


Here, r is the rank of the low-rank matrices, that is, the inner dimension of the adapter. lora_alpha is a scaling factor for the low-rank update, which is scaled by lora_alpha / r. lora_dropout is the dropout probability applied in the LoRA layers.
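Here is how those three parameters enter the computation. This is a plain-Python sketch of one LoRA-adapted linear layer. The pre-trained weight W stays frozen, and only A (r × in) and B (out × r) are trained. The update is scaled by lora_alpha / r, which is the default scaling in Hugging Face PEFT. Dropout is applied only to the input of the adapter branch. The function names are illustrative.

```python
import random


def matvec(matrix, vector):
    """Plain-Python matrix-vector product."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, lora_alpha, r, lora_dropout=0.0, training=False, rng=random):
    """h = W x + (lora_alpha / r) * B (A dropout(x)).

    W: frozen base weight (out x in); A: trainable (r x in); B: trainable (out x r).
    Inverted dropout (rescaled by 1/keep) only touches the LoRA branch input.
    """
    if training and lora_dropout > 0:
        keep = 1.0 - lora_dropout
        x_lora = [xi / keep if rng.random() < keep else 0.0 for xi in x]
    else:
        x_lora = x
    base = matvec(W, x)
    update = matvec(B, matvec(A, x_lora))
    scaling = lora_alpha / r
    return [b + scaling * u for b, u in zip(base, update)]


W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # frozen 2x3 weight
x = [1.0, 2.0, 3.0]
A = [[1.0, 1.0, 1.0]]                   # r = 1
B_init = [[0.0], [0.0]]                 # B starts at zero...
print(lora_forward(W, A, B_init, x, lora_alpha=2, r=1))  # [1.0, 2.0], i.e. exactly W x
B = [[1.0], [0.0]]                      # ...and training moves it away from zero
print(lora_forward(W, A, B, x, lora_alpha=2, r=1))       # [13.0, 2.0]
```

In the LoRA paper, B is initialized to zero, so the patched model starts out behaving exactly like the original one.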

Next, we need a dataset to fine-tune on. For this purpose, we will use alpaca-cleaned, a cleaned version of the 52k-sample Alpaca dataset originally released by Stanford. The Alpaca dataset is designed for instruction tuning of pre-trained language models💨. You can replace it with your own dataset. Each record looks like the format below-


As we saw in the last blog post, the Mistral 7B model expects prompts in a specific format, so after loading the dataset we convert each record into that format as below-
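The post's exact conversion code isn't reproduced here. As a sketch, assume an Alpaca-style prompt template, the layout used in unsloth's example notebooks. If your previous post used Mistral's [INST] chat format, swap the template. A batched formatter for Hugging Face Datasets could then look like this. `ALPACA_TEMPLATE` and `formatting_prompts_func` are illustrative names. `"</s>"` is Mistral's end-of-sequence token; in a real run, use `tokenizer.eos_token`.

```python
ALPACA_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = "</s>"  # assumption: Mistral's EOS; prefer tokenizer.eos_token


def formatting_prompts_func(examples):
    """Batched formatter: dict of column lists in, new "text" column out."""
    texts = []
    for instruction, inp, output in zip(examples["instruction"], examples["input"], examples["output"]):
        # Appending EOS teaches the model where a response should stop.
        texts.append(ALPACA_TEMPLATE.format(instruction, inp, output) + EOS_TOKEN)
    return {"text": texts}


# Placeholder batch with the same three columns as the Alpaca data.
batch = {
    "instruction": ["Add the two numbers.", "Say hello."],
    "input": ["2 and 3", ""],
    "output": ["5", "Hello!"],
}
print(formatting_prompts_func(batch)["text"][0])
```

With a real dataset, `dataset = dataset.map(formatting_prompts_func, batched=True)` adds the `text` column, and the trainer can then be pointed at that column.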


The download will look like below-


Next, we can start training the model on this formatted dataset using Hugging Face TRL's SFTTrainer. TRL is a full-stack library for training transformer language models with reinforcement learning. It covers the Supervised Fine-tuning (SFT) step, the Reward Modeling (RM) step and the Proximal Policy Optimization (PPO) step. The library is integrated with transformers-


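Here, per_device_train_batch_size is the batch size per GPU/TPU core/CPU for training. gradient_accumulation_steps is the number of forward/backward passes whose gradients are accumulated before each optimizer update. The effective batch size is therefore per_device_train_batch_size times gradient_accumulation_steps (times the number of devices). fp16 sets whether to use 16-bit (mixed) precision training instead of 32-bit training. The T4 has no native bfloat16 support, which is why fp16 is the mixed-precision option there.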
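The batch arithmetic is easy to check by hand. This sketch uses example values only (not necessarily the notebook's settings), and the helper names are hypothetical. The step count is approximate, since trainers differ slightly in how they round a final partial batch.

```python
import math


def effective_batch_size(per_device_train_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of examples that contribute to each optimizer update."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices


def approx_steps_per_epoch(num_examples, per_device_train_batch_size,
                           gradient_accumulation_steps, num_devices=1):
    """Approximate optimizer updates for one pass over the data."""
    batch = effective_batch_size(per_device_train_batch_size,
                                 gradient_accumulation_steps, num_devices)
    return math.ceil(num_examples / batch)


# Example values: batch of 2 per device, 4 accumulation steps, ~52k examples.
print(effective_batch_size(2, 4))            # 8
print(approx_steps_per_epoch(52_000, 2, 4))  # 6500
```

This is why gradient accumulation helps on a T4. Memory use scales with the per-device batch, while the optimizer effectively sees the larger accumulated batch.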

Before starting the training, we can check the notebook's GPU memory as below-
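The memory-check code itself isn't shown here. Below is a small helper sketch for turning raw byte counts into a readable report. It assumes the inputs come from `torch.cuda.max_memory_reserved()` and `torch.cuda.get_device_properties(0).total_memory`, both real PyTorch calls. The helper names are hypothetical, and "GiB" means 1024³ bytes. The same arithmetic works after training: subtract the "before" reading from the "after" reading to see what the training run itself added.

```python
GIB = 1024 ** 3


def memory_report(reserved_bytes: int, total_bytes: int) -> str:
    """Human-readable summary of reserved vs. total GPU memory."""
    used = reserved_bytes / GIB
    total = total_bytes / GIB
    return f"{used:.3f} GiB reserved of {total:.3f} GiB ({100 * used / total:.1f}%)"


def training_overhead(before_bytes: int, after_bytes: int, total_bytes: int) -> str:
    """Extra memory reserved by the training run itself."""
    delta = after_bytes - before_bytes
    return f"training added {delta / GIB:.3f} GiB ({100 * delta / total_bytes:.1f}% of the GPU)"


# On a real GPU the inputs would come from PyTorch, e.g.:
#   import torch
#   total = torch.cuda.get_device_properties(0).total_memory
#   reserved = torch.cuda.max_memory_reserved()
# Synthetic numbers, just to show the output format:
print(memory_report(4 * GIB, 16 * GIB))
print(training_overhead(4 * GIB, 6 * GIB, 16 * GIB))
```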


Now, we will start the training using the following command-


The training took approximately 15 minutes 36 seconds for one epoch. When I checked the memory after it finished, I was quite surprised to see the following stats-


Now, we can test our fine-tuned model on a given prompt🙋. For example, here I ask my model to complete the Fibonacci sequence as below-

Here, in the prompt template, I give one sample sequence from the Fibonacci series and ask the model to continue it, with a limit of 128 new tokens. You can increase this limit to see a longer completion in the final output. I got the following output from my model-
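A quick way to judge a completion like this is to compare it with a reference Fibonacci implementation. In this sketch, the prefix `[1, 1, 2, 3, 5, 8]` and the sample outputs are hypothetical, since the post's exact prompt isn't reproduced here. Pass only the newly generated text, not the echoed prompt.

```python
import re


def fibonacci_continuation(prefix, n_more):
    """Extend a Fibonacci-style sequence (each term = sum of the previous two).
    Assumes the prefix has at least two terms."""
    seq = list(prefix)
    for _ in range(n_more):
        seq.append(seq[-1] + seq[-2])
    return seq[len(prefix):]


def parse_numbers(text):
    """Pull the integers out of the model's generated text."""
    return [int(tok) for tok in re.findall(r"-?\d+", text)]


def check_completion(prefix, generated_text):
    """Return (correct leading terms, total generated terms)."""
    numbers = parse_numbers(generated_text)
    expected = fibonacci_continuation(prefix, len(numbers))
    correct = 0
    for got, want in zip(numbers, expected):
        if got != want:
            break  # stop at the first mistake
        correct += 1
    return correct, len(numbers)


print(check_completion([1, 1, 2, 3, 5, 8], "13, 21, 34, 55"))      # (4, 4)
print(check_completion([1, 1, 2, 3, 5, 8], "13, 21, 34, 54, 89"))  # (3, 5)
```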

Looks promising, right!😇 We successfully fine-tuned a 4-bit quantized Mistral 7B model on a dataset. Next, to save this fine-tuned model, we can use Hugging Face's push_to_hub to save it online or save_pretrained to save it locally, as below. For a PEFT model, this saves only the LoRA adapter weights, not the full base model-

Next, to load this saved model, we can use the PeftModel class from Hugging Face's PEFT library as below-

That's it, guys. In this blog post, we learned how to load a 4-bit quantized Mistral 7B model and fine-tune it on a dataset. Along the way, we learned about cutting-edge techniques like QLoRA and DPO, and saw how to use a very powerful large language model on a simple T4 GPU-enabled Colab notebook. Phew, that was indeed quite a lot of learning today👷. But don't wait: get hands-on by copying my shared notebook into your Colab environment and playing with the parameters of the APIs I used. In the next post, we will learn about retrieval augmented generation in Gen AI. Till then 👉 Go chase your dreams, have an awesome day, make every second count, and see you later in my next post😇



