Hugo Ventura

Behind the scenes of IMAGINE 7B training

Introduction

Interaction Model for Advanced Graphics Inference and Exploration

In this blog post, I will take you through the journey of creating IMAGINE, an LLM (Large Language Model) trained to craft top-quality prompts for Midjourney.

Over the course of several days, I encountered challenges, made improvements, and ultimately arrived at a satisfying result.

Day #1: A Magical Beginning

I started by selecting a base model to train IMAGINE on.

My first choice was LLaMa2 13B because of its performance on various benchmarks. According to Meta AI, it outperforms or matches MPT 30B and Falcon 40B.

But running a 13B model locally requires a lot of resources. Since I want IMAGINE to be available to almost everyone, I had to make a compromise and turn to a 7B model.

After a few hours of looking at various benchmark results, one model caught my attention: Mistral 7B.
Trained by Mistral AI (a French company 🐓), it allegedly outperforms LLaMa2 7B and even 13B! — More info

With the base model finally selected, a small dataset of 1,000 prompts, an A100 GPU, and almost zero knowledge of how to fine-tune an LLM, I began the journey!

This was my first dive into training AI from scratch, and the experience felt nothing short of magical.
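For the curious, here is a minimal sketch of one common way to fine-tune a 7B model on a single A100: parameter-efficient fine-tuning with LoRA via Hugging Face's transformers and peft libraries. The checkpoint name and hyperparameters below are illustrative assumptions, not the exact setup used for IMAGINE.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

base_model = "mistralai/Mistral-7B-v0.1"  # assumed checkpoint

tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model, device_map="auto")

# LoRA freezes the 7B base weights and trains small adapter matrices instead,
# which is what makes fine-tuning on a single A100 practical.
lora_config = LoraConfig(
    r=16,                                 # adapter rank (illustrative)
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention projections to adapt
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()        # typically well under 1% of all weights
```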

After this first day the plan ahead was clear:

  • Enhance the dataset
  • Fix bugs
  • Deploy IMAGINE on HuggingFace for everyone to use

Day #2: Tackling Overfitting

Following the successful first day of training with a small dataset, I decided to retry the experiment with a much larger dataset: 6,000 prompts.

But during the training process, I encountered an issue I had read about on day 1: overfitting.

Overfitting occurs when a machine learning model is excessively trained on a specific dataset to the point where it starts to memorize the patterns and noise in that particular dataset, rather than learning the general underlying patterns.
As a result, the model becomes highly specialized and performs extremely well on the training data, but fails to generalize its knowledge to unseen data.
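A common way to spot overfitting is to hold out a validation split and compare its loss curve against the training loss. A minimal sketch (the file name and split size are illustrative, not the actual dataset):

```python
from datasets import load_dataset

# Load the prompt dataset (hypothetical file name) and hold out 10% for validation.
dataset = load_dataset("json", data_files="prompts.jsonl", split="train")
splits = dataset.train_test_split(test_size=0.1, seed=42)
train_ds, eval_ds = splits["train"], splits["test"]

# During training, watch both curves: if validation loss bottoms out and starts
# climbing while training loss keeps falling, the model has begun memorizing
# the training data instead of generalizing.
```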

Additionally, I faced another, minor issue with the dataset formatting: nothing major, mostly aesthetic glitches in the model's responses.

Day #3: Fixing the Previous Session's Issues

To mitigate the overfitting issue encountered in the previous session, I reduced the dataset to around 2,000 prompts and completely reworked the training workflow.

First, I trained on batches of 500 prompts and cross-validated the trainer reports. But since the dataset was not big enough in the first place, it was a complete waste of time.

After a couple of hours of research, I decided to increase the frequency at which the trainer reported the validation_loss and train_loss to pinpoint exactly where the model started to overfit, so I could stop the training there. And it worked!
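In practice, this kind of setup can be expressed with Hugging Face's Trainer: evaluate every few steps instead of once per epoch, and stop once validation loss stops improving. The sketch below is an assumption of what such a configuration might look like, not the exact one used here; it reuses the model and tokenized splits from the earlier sketches, and the step counts are illustrative.

```python
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

# Report eval_loss and train_loss every few steps so the point where the two
# curves diverge is easy to spot.
training_args = TrainingArguments(
    output_dir="imagine-checkpoints",   # hypothetical path
    num_train_epochs=3,
    eval_strategy="steps",              # named evaluation_strategy in older versions
    eval_steps=25,
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    load_best_model_at_end=True,        # roll back to the best checkpoint
    metric_for_best_model="eval_loss",
)

trainer = Trainer(
    model=model,                        # model and tokenized splits as in the
    args=training_args,                 # earlier sketches
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # Stop automatically once eval_loss has not improved for 3 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
trainer.train()
```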

Overfitting was no longer an issue, and in the meantime I updated the way the dataset was formatted to fix the issue from the previous session.
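I won't reproduce the final schema here, but an instruction-style record stored as JSONL might look like this; the field names and example content are made up for illustration.

```python
import json

# One instruction/response pair (hypothetical wording and fields).
record = {
    "instruction": "Describe a cozy mountain cabin at dusk.",
    "response": "cozy wooden cabin nestled in snowy mountains at dusk, warm window "
                "light, soft falling snow, cinematic lighting, ultra detailed",
}

# Append the record to the dataset file used in the earlier sketch.
with open("prompts.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")
```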

This day marked significant progress, leaving me excited for the next steps.

Day #4: Fine-tuning for Better Instruction Understanding

Day #4 was mostly about refining the dataset by changing the way each prompt was described.

Training took over 20 minutes on an A100 (80GB), during which I encountered a new issue. Despite detailed instructions, prompts sometimes failed to transcribe elements accurately due to a lack of diverse data.

To address this, there were only two options:

  1. Generate more images with Midjourney featuring diverse prompts.
  2. Merge the dataset with an open-source one.

Although I preferred the first option, it required a significant investment of both time and money.

Day #5: Refining the Model

A few days went by between day #4 and day #5. During that time, I generated thousands of images with Midjourney to feed the dataset.

By significantly increasing the dataset size to include both longer and shorter prompts, the repetition issue observed earlier with the smaller dataset vanished.

Once again, as on day #3, I rewrote the training workflow to adopt a faster and more versatile approach for training future LLMs.

The final dataset contains a little under 10,000 labeled prompts and, as predicted, required many hours of dedicated prompting in Midjourney.

To train v0.2, I used a different GPU cluster than previously:

  • 4x Nvidia Tesla V100S (32 GiB) GPUs
  • 160 GiB of RAM
  • 52 vCPU cores

Training lasted 6 hours and resulted in approximately 0.72 kg of CO₂ emissions.
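I won't go into how the emissions figure is obtained, but one common option is to wrap the training run in the codecarbon package. A minimal sketch, assuming the trainer from the earlier sketches (the project name is hypothetical):

```python
from codecarbon import EmissionsTracker

# With 4 GPUs, the run itself would typically be launched with something like
# `torchrun --nproc_per_node=4 train.py`; the tracker below only measures it.
tracker = EmissionsTracker(project_name="imagine-v0.2")  # hypothetical name
tracker.start()
try:
    trainer.train()                  # the training run from the earlier sketch
finally:
    emissions_kg = tracker.stop()    # estimated kg of CO2-equivalent
    print(f"Estimated emissions: {emissions_kg:.2f} kg CO2eq")
```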

Conclusion

Creating IMAGINE has been an exciting and challenging journey.

From the first steps of learning how to fine-tune an AI, to training, refining the model and addressing glitches, each day was filled with suspense and uncertainty.

As we move closer to the official release, I am thrilled with the results and eager to share IMAGINE with the world.