How To Train Your Own GenAI Model

Train lightweight LLMs like GPT2 for specialized tasks without having to use ChatGPT

While GenAI dominates the news and earnings reports around the world, most of the attention goes to massive, powerful, and expensive large language models (LLMs). Every business wants to slam their credit card on the table and integrate AI into their products. You can accomplish a lot by hooking an input on your platform up to OpenAI’s GPT service: it’s relatively affordable, the results are usually best in class, and it’s often the easiest solution.

The benefits of using GPT3+ are clear, but if your use case is simple enough, it actually makes more sense to use GPT2. First, GPT2 is open source, so you can use your business’s private data without exposing it externally. Second, GPT2 is far more lightweight (distilgpt2 weighs less than 400MB). In fact, GPT2 can run on mobile phones (including CoreML support on iOS). This means you can have a lightweight model, running on private data, with the durability to operate even when devices are offline.

If I were to summarize the goal of this article, it’s that we’re going to learn to light a campfire with a lighter (GPT2) and not a flamethrower (GPT3.5).

I need a drink

We’re going to build a GenAI model that takes the name of a cocktail (real or made up) and generates the ingredients for that cocktail. To visualize how a transformer will learn, let’s take a look at some of the data we’re going to train.

Name              Ingredients
Moscow Mule       [vodka, ginger beer]
Mexican Mule      [tequila, ginger beer]
Classic Negroni   [campari, gin]
Oaxacan Negroni   [mezcal, campari]

A model will learn the associations in each recipe the same way we do. Cocktails with mule in their name almost always contain ginger beer. Comparatively, Oaxacan usually means the recipe calls for mezcal, not tequila. What if someone suggests an Oaxacan Mule, and it’s not in our data set? This tutorial will teach us how to train a model to ✨generate✨ a recipe.

Getting started

Before we get started, it’s important to note that there is a wide range of open-source models you can use for these Seq2Seq tasks. A Seq2Seq model learns how one sequence maps to another. For example, a phrase in English has a corresponding phrase in Italian. Training a Seq2Seq model on English phrases and their Italian counterparts effectively builds us a translator.

Similarly, a Seq2Seq model trained on Q&A will learn how to translate an incoming question into the proper answer. You can train a model on abstracts and papers to teach it to summarize or expand on certain topics. Unlike multiclass models, the output isn’t a fixed set of labels; in a lot of ways, Seq2Seq is like regression for text. A more in-depth reading can be found here.
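
To make that concrete, here is a minimal sketch of how such pairs can be laid out for a decoder-only model like GPT2. Nothing here is part of the pipeline; the "name | ingredients" format is simply the one we use later in this tutorial.

# Illustrative only: Seq2Seq-style pairs flattened into single strings
# that a decoder-only model like GPT2 can learn to auto-complete.
pairs = [
    ("Moscow Mule", "vodka, ginger beer"),
    ("Classic Negroni", "campari, gin"),
]
training_texts = [f"{name} | {ingredients}" for name, ingredients in pairs]
print(training_texts[0])  # Moscow Mule | vodka, ginger beer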

We’ll focus mostly on GPT2 since it is the most popular, but I also encourage you to experiment with other transformers like T5. Whereas GPT works like an auto-complete, where your input will also be in the output, T5 just gives you the output.

GPT                           T5
GPT(input) = input + output   T5(input) = output

For this example it won’t really matter, but if you’re going to use this on larger bodies of text, you might want to use T5 so the output doesn’t have to repeat the entire input.
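
If you want to see the difference yourself, here is a quick sketch using the transformers pipeline helper. It is illustrative only; t5-small happens to ship with a built-in English-to-French translation task, which we borrow purely for demonstration.

from transformers import pipeline

# GPT-style models echo the prompt and then keep completing it
gpt = pipeline('text-generation', model='distilgpt2')
print(gpt("A Moscow Mule contains", max_length=20)[0]['generated_text'])
# -> the prompt itself, followed by the model's continuation

# T5-style models return only the output sequence
t5 = pipeline('text2text-generation', model='t5-small')
print(t5("translate English to French: ginger beer")[0]['generated_text'])
# -> just the translation, without the prompt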

Lastly, I must emphasize that you will need either an NVIDIA GPU or Apple Silicon for this exercise. If you do not have a laptop with Apple Silicon (M-Series chip), you can access free GPU resources on Google Colab.

I don’t have all day

The real challenge of training these models is engineering a training process that runs efficiently. Here are the core concepts you need to consider before you start training. Small changes can cut your iteration time from hours to minutes. The pluses and minuses below tell you whether increasing something will speed training up (+) or slow it down (-).

  • GPU VRAM (+): The GPU you use will be the most impactful tool when it comes to training time. GPU RAM (VRAM) determines how much you can train concurrently, but it will not necessarily train your data faster. If you are training with larger models like gpt2-large, or with larger batch sizes or lengths, you might want a GPU with much more VRAM. This is helpful if you’re experimenting with multiple models in parallel in order to find the model that gives you the best performance. In 2023, commercially available GPUs offer up to 48GB of VRAM.
  • GPU Compute (+): Some GPUs are faster than others. Newer GPUs optimized for speed (like the RTX 4090) will rip through training much faster than other models. You don’t need a 4090 for this exercise, but if you need to cut your training time, use a faster GPU.
  • Max Length (-): Simply put, what is the longest input (or output) we expect to see in our model? This determines how many positions the model has to learn to align on. For faster training, minimize max length. For example, if 95% of your training data is under 64 tokens long (think characters), but 5% runs up to 512 tokens, consider clipping out that 5%. It will save you tons of time and will likely give you similar results.
  • Batch Size (-): This won’t change the training time much, but depending on your VRAM, it determines how much data can be processed at once. Batch size can also have a real effect on model performance: a high batch size is like drinking out of a fire hose, so keep that in mind as you experiment with different training parameters.
  • Number of Records (-): More records, longer training time. Depending on the model you use, you might need more or less data; smaller transformers generally need more data than larger ones. If you have too much data, try reducing your data set. We want diversity, not size, here: 10K great records are better than 1M mediocre ones. It might not change your performance that much, but reducing the number of records will speed up your training time. This open-source data set is very small, with fewer than 10K records. (For a rough sense of how these knobs interact, see the back-of-envelope sketch after this list.)
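
As a rough back-of-envelope illustration of how records, batch size, and per-step speed combine into epoch time (the numbers below are placeholders, not measurements):

import math

# Illustrative numbers only: steps per epoch fall out of records / batch size,
# and seconds per step depends on your GPU, model size, and max length.
n_records = 7000
batch_size = 8
seconds_per_step = 0.2  # hypothetical; measure this on your own hardware

steps_per_epoch = math.ceil(n_records / batch_size)   # 875 steps
epoch_seconds = steps_per_epoch * seconds_per_step    # ~175 seconds per epoch
print(steps_per_epoch, epoch_seconds)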

Tutorial

You’ll need the following packages installed to run all of the code for this tutorial.

pip install torch torchtext transformers sentencepiece pandas tqdm datasets

First, download the data from Huggingface and convert it to a pandas DataFrame. The data set also includes units and quantities for each ingredient, but we’re going to remove those. Quantities are an exceptionally difficult problem for these models. While it is possible, you would probably need to train other transformers to parse and standardize the units and quantities. It’s all fun and games until the model recommends 13 cups of vodka for every 1 cup of orange juice because it doesn’t know how to read fractions.

from datasets import load_dataset
import pandas as pd
import ast
from tqdm import tqdm
import time

# Load the data set from Hugging Face
dataset = load_dataset("erwanlc/cocktails_recipe_no_brand")

# Convert to a pandas dataframe
data = [{'title': item['title'], 'raw_ingredients': item['raw_ingredients']} for item in dataset['train']]
df = pd.DataFrame(data)

# Just extract the ingredient names, nothing else
df.raw_ingredients = df.raw_ingredients.apply(lambda x: ', '.join([y[1] for y in ast.literal_eval(x)]))
display(df.head())
title               raw_ingredients
Abacaxi Ricaço      pineapple, white rum, lime juice, white sugar
Abbey               gin, lillet blanc, orange juice, angostura
A.B.C. Cocktail     mint leaves, tawny port, cognac, maraschino, sugar syrup
Absinthe Cocktail   absinthe, chilled water, sugar syrup
Absinthe Frappé     absinthe, anisette liqueur, chilled water, sugar syrup

Lastly, let’s import the transformers and set up the GPU (that’s what device refers to). The tokenizer always accompanies the model, since it encodes (and decodes) the input and output text for both us and the model. Each model series has its own tokenizer, so be sure not to mix them up! The GPT tokenizer works for all GPT models, but it won’t work with T5, which has its own tokenizer. We’re going to use distilgpt2 since it comes close to the performance of gpt2 while being much lighter and faster. If you want to use gpt2 or gpt2-large, all the code should still work, but you will likely have to adjust one of the parameters mentioned above since they are larger models.
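
To make the encode/decode step concrete, here is a tiny standalone sketch of the tokenizer round trip. Nothing in it is used later; it just shows what the tokenizer does.

from transformers import GPT2Tokenizer

# Encode text to integer token ids, then decode back to text
tok = GPT2Tokenizer.from_pretrained('distilgpt2')
ids = tok.encode("Moscow Mule | vodka, ginger beer")
print(ids)              # a short list of integer token ids
print(tok.decode(ids))  # Moscow Mule | vodka, ginger beer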

from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

# Models need to be attached to hardware
# If you have an NVIDIA GPU attached, use 'cuda'
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    # Apple Silicon (M-series) exposes its GPU through 'mps'
    device = torch.device('mps')
else:
    # Fall back to CPU (not advised; training will be very slow)
    device = torch.device('cpu')

# The tokenizer turns texts to numbers (and vice-versa)
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

# The transformer
model = GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)

# Model params
BATCH_SIZE = 8

Clip data to reduce max length

Let’s take a look at the distribution of lengths for the inputs and outputs of this model. We want to find the long tail and remove it from our data set. Your cutoff should fall somewhere after the long tail starts: make a visual guess from a histogram of the lengths, or use the 75th percentile as a loose gauge. Do not use the actual max from your data unless you have to.

# Distribution of ingredient-string lengths (in characters)
df.raw_ingredients.str.len().describe()

# count    6956.000000
# mean       98.147930
# std        32.458764
# min        10.000000
# 25%        75.000000
# 50%        95.000000
# 75%       117.000000 <- Tip: Cut off at next largest power of 2 (128)
# max       292.000000
# Name: raw_ingredients, dtype: float64
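
If you want to actually drop that long tail rather than rely on truncation at tokenization time, a minimal optional clipping step could look like this (using the 128-character cutoff we just picked):

# Optional: drop rows whose ingredient string exceeds the chosen cutoff
MAX_CHARS = 128
before = len(df)
df = df[df.raw_ingredients.str.len() <= MAX_CHARS].reset_index(drop=True)
print(f"Dropped {before - len(df)} long rows, {len(df)} remain")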

Lock and loaders

Unlike sklearn, transformers will need a little bit of handling to prepare the data to be read and loaded into the training loop. We’re going to use a DataLoader and Dataset extension to do that. From PyTorch:

PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

def fittest_max_length(df):
    """
    Smallest power of two larger than the longest string in the data set.
    Important for setting max length, which drives training time.
    """
    max_length = max(len(max(df[df.columns[0]], key=len)), len(max(df[df.columns[1]], key=len)))
    x = 2
    while x < max_length:
        x = x * 2
    return x


# Dataset Prep
class LanguageDataset(Dataset):
    """
    An extension of the Dataset object to:
      - Make the training loop cleaner
      - Make ingestion easier from pandas DataFrames
    """
    def __init__(self, df, tokenizer):
        self.labels = df.columns
        self.data = df.to_dict(orient='records')
        self.tokenizer = tokenizer
        self.max_length = fittest_max_length(df)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx][self.labels[0]]
        y = self.data[idx][self.labels[1]]
        # Join name and ingredients into a single "input | output" string
        text = f"{x} | {y}"
        # 128 = the power-of-two cutoff picked from the length distribution above;
        # swap in self.max_length if you prefer to derive it from the data
        tokens = self.tokenizer.encode_plus(text, return_tensors='pt', max_length=128,
                                            padding='max_length', truncation=True)
        return tokens

# Cast the pandas dataframe as the LanguageDataset we defined above
dataset = LanguageDataset(df, tokenizer)

# Create train and validation splits
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
train_data, valid_data = random_split(dataset, [train_size, valid_size])

# Make the iterators
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)
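
As an optional sanity check before training, you can peek at one batch. This also shows why the training loop below calls .squeeze(1) on the input ids.

# GPT2 has no pad token, so set one before pulling a padded batch
# (the training section below does this too; doing it twice is harmless)
tokenizer.pad_token = tokenizer.eos_token

# Each item from encode_plus has shape (1, max_length), so a batch stacks to
# (batch_size, 1, max_length); the training loop squeezes out the middle dim
sample_batch = next(iter(train_loader))
print(sample_batch['input_ids'].shape)  # e.g. torch.Size([8, 1, 128])
print(tokenizer.decode(sample_batch['input_ids'][0][0][:20]))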

Training

We finally made it! This is what the training loop for GPT2 looks like. Note that this isn’t a universal training loop; other models will have slightly different loops. Since these jobs take time, we’re adding tqdm to integrate progress bars into our training. This helps you assess how long an epoch takes and gives us a live view of performance while the model is training. I highly recommend it, since it will quickly tell you whether you need to adjust a training parameter to speed things up.

# Set the number of epochs
num_epochs = 3

# Training parameters
batch_size = BATCH_SIZE
model_name = 'distilgpt2'
gpu = 0

# GPT2 has no pad token by default, so reuse the end-of-sequence token
tokenizer.pad_token = tokenizer.eos_token

# Set the learning rate and loss function
## CrossEntropyLoss measures how close the model's answers are to the truth.
## It is more punishing for high-confidence wrong answers.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

# Init a results dataframe
results = pd.DataFrame(columns=['epoch', 'transformer', 'batch_size', 'gpu',
                                'training_loss', 'validation_loss', 'epoch_duration_sec'])
# The training loop
for epoch in range(num_epochs):
    start_time = time.time()  # Start the timer for the epoch

    # Training
    ## This line tells the model we're in 'learning mode'
    model.train()
    epoch_training_loss = 0
    train_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs} Batch Size: {batch_size}, Transformer: {model_name}")
    for batch in train_iterator:
        optimizer.zero_grad()
        inputs = batch['input_ids'].squeeze(1).to(device)
        targets = inputs.clone()
        outputs = model(input_ids=inputs, labels=targets)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_iterator.set_postfix({'Training Loss': loss.item()})
        epoch_training_loss += loss.item()
    avg_epoch_training_loss = epoch_training_loss / len(train_iterator)

    # Validation
    ## This line below tells the model to 'stop learning'
    model.eval()
    epoch_validation_loss = 0
    total_loss = 0
    valid_iterator = tqdm(valid_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}")
    with torch.no_grad():
        for batch in valid_iterator:
            inputs = batch['input_ids'].squeeze(1).to(device)
            targets = inputs.clone()
            outputs = model(input_ids=inputs, labels=targets)
            loss = outputs.loss
            total_loss += loss.item()
            valid_iterator.set_postfix({'Validation Loss': loss.item()})
            epoch_validation_loss += loss.item()

    avg_epoch_validation_loss = epoch_validation_loss / len(valid_loader)

    end_time = time.time()  # End the timer for the epoch
    epoch_duration_sec = end_time - start_time  # Calculate the duration in seconds

    new_row = {'transformer': model_name,
               'batch_size': batch_size,
               'gpu': gpu,
               'epoch': epoch+1,
               'training_loss': avg_epoch_training_loss,
               'validation_loss': avg_epoch_validation_loss,
               'epoch_duration_sec': epoch_duration_sec}  # Add epoch_duration to the dataframe

    results.loc[len(results)] = new_row
    print(f"Epoch: {epoch+1}, Validation Loss: {total_loss/len(valid_loader)}")

Once the training is done, let’s take a look at the results. The metric to pay attention to here is validation_loss: we want it to keep dropping until it bottoms out. It looks like we hit our minimum after two epochs. You can run this for more epochs, but given that we only have around 7K records, we probably won’t see any improvement after epoch two. I’ve also run this for six epochs, and I’d like to save you some time by confirming that it does not, in fact, get better.

Training Epoch 1/3 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 696/696 [02:30<00:00,  4.61it/s, Training Loss=0.316]
Validation Epoch 1/3: 100%|██████████| 174/174 [00:10<00:00, 16.34it/s, Validation Loss=0.342]
Epoch: 1, Validation Loss: 0.36245641112327576

Training Epoch 2/3 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 696/696 [02:27<00:00,  4.72it/s, Training Loss=0.29]
Validation Epoch 2/3: 100%|██████████| 174/174 [00:10<00:00, 16.47it/s, Validation Loss=0.32]
Epoch: 2, Validation Loss: 0.34115156531333923

Training Epoch 3/3 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 696/696 [02:27<00:00,  4.73it/s, Training Loss=0.223]
Validation Epoch 3/3: 100%|██████████| 174/174 [00:10<00:00, 16.42it/s, Validation Loss=0.342]
Epoch: 3, Validation Loss: 0.34880194067955017
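
The progress bars give a live view, but the results dataframe we built makes the trend easier to see at a glance. A quick optional plot (this assumes matplotlib, which isn’t in the install list above):

import matplotlib.pyplot as plt

# Plot training vs validation loss per epoch from the results dataframe
results.plot(x='epoch', y=['training_loss', 'validation_loss'], marker='o')
plt.ylabel('loss')
plt.show()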

Results

Now that we have a GenAI model trained on the recipes, we can generate ingredients from a cocktail name! Let’s look at an example below. As you can see, GPT2 generation requires a couple of parameters accompanying the input string. The values placed in these parameters have already been tuned for cocktail generation. After you are done training on your own data, make sure you also tune the generation parameters until the model’s output matches the output you want.

input_str = "espresso martini"
input_ids = tokenizer.encode(input_str, return_tensors='pt').to(device)

output = model.generate(
    input_ids,
    max_length=16,
    num_return_sequences=1,
    do_sample=True,
    top_k=8,
    top_p=0.95,
    temperature=0.5,
    repetition_penalty=1.2
)

decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)

When we pass in espresso martini, we get the following:

espresso martini | vodka, espresso coffee, sugar syrup, bob's chocolate

This looks correct! But what happens when we pass the hypothetical Oaxacan Mule we talked about before?

Oaxacan Mule | mezcal, lime juice, sugar syrup,

So we got the mezcal and the lime juice we would expect from a Mule, but we are missing ginger beer. If we go into the data we trained on, it looks like ginger beer is actually missing from a lot of Mule drinks. However, virtually all of them include lime juice. So the model is learning correctly, but with incomplete data. As always, make sure you check the underlying training data to validate your model qualitatively:
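
The rows below were pulled with a quick filter along these lines (a hypothetical one-liner; the column names are the same as in our DataFrame):

# Spot-check every training recipe with "mule" in its name
display(df[df.title.str.contains('mule', case=False)].head())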

title             raw_ingredients
Dead Man's Mule   absinthe, cinnamon schnapps & goldwasser liqueurs, giffard orgeat syrup, lime juice, gin
French Mule       cognac, lime juice, sugar syrup, angostura, gin
Gin Gin Mule      gin, gin, lime juice, sugar syrup, mint leaves, gin
Jamaican Mule     spiced rum, lime juice, sugar syrup, gin
Limey Mule        vodka, lime juice, sugar syrup, gin

Conclusion

Now that you have trained your model, you can save it; from there, it’s up to you to deploy it. The distilgpt2 model in particular is small and easy to deploy.

torch.save(model, 'drinkGenGPT2.pt')
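
To load it back later for inference, a minimal sketch looks like this (we pickled the whole model object above, so note the version caveat in the comment):

# Reload the full pickled model and put it in inference mode
# (on PyTorch 2.6+ pass weights_only=False, since we saved the whole object)
model = torch.load('drinkGenGPT2.pt', map_location=device)
model.eval()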

Now that you are more comfortable training a GPT2 model, there are a couple of things you can do to continue your learning.

  • Apply to Your Data: While this is a small exercise with open-source, noncommercially licensed data, the most powerful application of this technique is going to be within your own org. You can now leverage micro LLMs on your own data (or your org’s data) for new applications for your stakeholders and customers.
  • Bigger Can Be Better: This is a very small model trained on a very small data set. I would highly recommend trying to run this technique with more data, larger models, and more GPU compute. Experiment with different model sizes on the same data. For simple applications like this one, you might find that, while larger models perform slightly better, they weigh (and cost) orders of magnitude more than the lighter models.
  • Push the Edge: For small enough models like the one we used in this article, you can actually deploy these models on edge devices like mobile phones. GPT2 models like distilgpt2 and gpt2 can even be deployed in iOS applications. They weigh in at roughly 500MB, which is a hefty size for an iOS application. However, if you expect your model to be a core feature, it might actually justify the weight. Consider the advantage of having a local GPT model on an edge device that can truly act autonomously without connecting to the internet.

I’m excited to see what you build!


This post is intended for educational purposes only. You are solely responsible for your use of any AI products and services mentioned herein. The content in this article does not in any way constitute legal advice relating to, or permission to use for any purpose, any of the referenced products or services. Your use of any AI product, service, or feature is subject to its own terms of use and privacy policies.
