Behind the scenes of IMAGINE 7B training
Introduction
Interaction Model for Advanced Graphics Inference and Exploration
In this blog post, I will take you through the journey of creating IMAGINE, an LLM (Large Language Model) trained to craft top-quality prompts for Midjourney.
Over the course of several days, I encountered challenges, made improvements, and ultimately arrived at a satisfying result.
Day #1: A Magical Beginning
I started by selecting a base model to train IMAGINE on.
My first choice was LLaMA 2 13B because of its performance on various benchmarks: according to Meta AI, it outperforms or matches MPT 30B and Falcon 40B.
But running a 13B model locally requires a lot of resources. Since I want IMAGINE to be available to almost everyone, I had to compromise and turn to a 7B model.
After a few hours of looking at various benchmark results, one model caught my attention: Mistral 7B.
Trained by Mistral AI (a French company 🐓), it reportedly outperforms LLaMA 2 7B and even 13B! — More info
With the base model finally selected, a small dataset of 1,000 prompts, an A100 GPU, and almost zero knowledge of how to fine-tune an LLM, I began the journey!
This was my first dive into training an AI model, and the experience felt nothing short of magical.
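The post doesn't detail the training setup, but a fine-tune like this is typically done with a parameter-efficient method rather than full-weight training. Below is a minimal, hypothetical sketch of a LoRA configuration using Hugging Face's peft and transformers; every hyperparameter here is an illustrative assumption, not the values actually used for IMAGINE.

```python
# Hypothetical LoRA fine-tuning configuration for a Mistral 7B base model.
# All hyperparameters are illustrative assumptions, not IMAGINE's actual setup.
from peft import LoraConfig
from transformers import TrainingArguments

lora_config = LoraConfig(
    r=16,                                 # low-rank adapter dimension
    lora_alpha=32,                        # adapter scaling factor
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

training_args = TrainingArguments(
    output_dir="imagine-7b",
    per_device_train_batch_size=4,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=True,          # mixed precision to fit comfortably on one A100
    logging_steps=50,
)
```

Adapting only a few attention projections keeps the trainable parameter count tiny, which is what makes a 7B fine-tune feasible on a single GPU.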
After this first day the plan ahead was clear:
- Enhance the dataset
- Fix bugs
- Deploy IMAGINE on HuggingFace for everyone to use
Day #2: Tackling Overfitting
Following the successful first day of training with a small dataset, I decided to retry the experiment with a much larger dataset: 6,000 prompts.
But during the training process, I ran into an issue I had read about on day 1: overfitting.
Overfitting occurs when a machine learning model is excessively trained on a specific dataset to the point where it starts to memorize the patterns and noise in that particular dataset, rather than learning the general underlying patterns.
As a result, the model becomes highly specialized and performs extremely well on the training data, but fails to generalize its knowledge to unseen data.
Additionally, I faced another issue with the dataset formatting, but nothing major: mostly aesthetic artifacts in the model's responses.
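For context, instruction-style fine-tunes usually serialize each example into a single template string, and small template mistakes (stray whitespace, missing end-of-sequence markers) tend to surface as exactly this kind of cosmetic artifact. The sketch below is hypothetical: the [INST] tags follow Mistral's chat convention, but IMAGINE's actual dataset format isn't shown in the post.

```python
def format_example(description: str, midjourney_prompt: str) -> str:
    """Serialize one training example as an instruction/response string.

    Hypothetical template: the [INST] tags follow Mistral's chat
    convention, not necessarily IMAGINE's actual dataset format.
    """
    return f"<s>[INST] {description.strip()} [/INST] {midjourney_prompt.strip()}</s>"

example = format_example(
    "A cozy cabin in a snowy forest at dusk",
    "cozy log cabin, snowy pine forest, dusk lighting, warm window glow --ar 16:9",
)
```

Keeping this rendering in one small function makes it easy to fix formatting bugs in one place and regenerate the whole dataset.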
Day #3: Fixing the Previous Sessions' Issues
To mitigate the overfitting issue encountered in the previous session, I reduced the dataset to around 2,000 prompts and completely reworked the training workflow.
First, I tried training on batches of 500 prompts and cross-validating the trainer's reports. But since the dataset was not big enough in the first place, this turned out to be a waste of time.
After a couple of hours of research, I decided to increase the frequency at which the trainer reported the validation_loss and train_loss, so I could pinpoint exactly where the model started to overfit and stop the training there. And it worked!
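The fix amounts to a simple rule: find the first evaluation step where validation_loss keeps climbing while train_loss keeps falling. Here is a minimal sketch of that check (the flat-list report shape is an assumption; a real trainer's logs are richer):

```python
def find_overfit_step(steps, train_losses, val_losses, patience=2):
    """Return the step where overfitting begins: validation loss rises for
    `patience` consecutive evaluations while training loss keeps falling.
    Returns None if no such point exists."""
    rising = 0
    for i in range(1, len(steps)):
        if val_losses[i] > val_losses[i - 1] and train_losses[i] < train_losses[i - 1]:
            rising += 1
            if rising >= patience:
                # divergence started `patience` evaluations earlier
                return steps[i - patience]
        else:
            rising = 0
    return None

steps = [100, 200, 300, 400, 500]
train = [1.20, 0.90, 0.70, 0.55, 0.45]
val   = [1.25, 1.00, 0.95, 1.05, 1.15]
best = find_overfit_step(steps, train, val)  # 300: val loss bottoms out here
```

With a small enough evaluation interval, the returned step gives a good point at which to stop training or restore the best checkpoint.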
Overfitting was no longer an issue, and in the meantime I updated the dataset formatting to fix the issue from the previous session.
This day marked significant progress, leaving me excited for the next steps.
Day #4: Fine-tuning for Better Instruction Understanding
Day #4 was mostly about refining the dataset by changing the way each prompt was described.
Training took over 20 minutes on an A100 (80 GB), during which I encountered a new issue: despite detailed instructions, the model sometimes failed to transcribe elements of the prompt accurately, due to a lack of diverse data.
To address this, there were only two options:
- Generate more images with Midjourney featuring diverse prompts.
- Merge the dataset with an open-source one.
Although I preferred the first option, it required a significant investment of time and money.
Day #5: Refining the Model
A few days went by between day #4 and day #5. During that time, I generated thousands of images with Midjourney to grow the dataset.
By significantly increasing the dataset size to include longer and shorter prompts, I saw the repetition issue, observed earlier with a smaller dataset, vanish.
Yet again, as on day #3, I rewrote the training workflow to adopt a faster and more versatile approach for training future LLMs with it.
The final dataset contains a little under 10,000 labeled prompts and, as predicted, required many hours of dedicated prompting in Midjourney.
To train v0.2, I used a different GPU cluster than previously:
- 4x Nvidia Tesla V100S (32GiB) GPUs
- 160 GiB RAM
- 52 vCPU cores
Training lasted 6 hours and resulted in approximately 0.72 kg of CO₂ emissions.
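That figure is consistent with a back-of-the-envelope estimate. Both constants below are assumptions (the post doesn't state its methodology): roughly 250 W average draw per V100S under load, and a grid carbon intensity of about 0.12 kg CO₂/kWh.

```python
# Back-of-the-envelope check of the reported 0.72 kg of CO2.
# Both constants are assumptions; the post doesn't state its methodology.
GPU_POWER_KW = 0.250     # assumed average draw per V100S under load
CARBON_INTENSITY = 0.12  # assumed kg CO2 per kWh of electricity

num_gpus = 4
hours = 6

energy_kwh = num_gpus * GPU_POWER_KW * hours   # 6.0 kWh
emissions_kg = energy_kwh * CARBON_INTENSITY   # 0.72 kg
```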
Conclusion
Creating IMAGINE has been an exciting and challenging journey.
From the first steps of learning how to fine-tune an AI, to training, refining the model and addressing glitches, each day was filled with suspense and uncertainty.
As we move closer to the official release, I am thrilled with the results and eager to share IMAGINE with the world.