Training a small Language model using Knowledge distillation on Amazon SageMaker

Published at
10/15/2022
Categories
nlp
aws
Author
Gokul S

In this article, we will learn about language models and knowledge distillation. We will perform task-specific knowledge distillation on a student model using Amazon SageMaker compute resources.

The complete code can be found in this GitHub repository: https://github.com/gokulsg/BERT-models-complete-code/blob/main/knowledge_distillation.py

In recent years, pre-trained transformer-based models have achieved state-of-the-art performance on many NLP tasks. Self-attention is a key component of these models: it processes all positions of a sequence in parallel, which makes training on GPUs, including multiple GPUs, efficient. Transformer models can be broadly divided into three classes:

  • Encoder-based models - use only the encoder block of the transformer network, e.g., BERT, RoBERTa
  • Decoder-based models - use only the decoder block of the transformer network, e.g., GPT
  • Encoder-decoder models - use both the encoder and decoder blocks, e.g., T5, BART
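To make the parallelism claim concrete, here is a minimal sketch of single-head scaled dot-product self-attention in NumPy. It is an illustration, not BERT's implementation: the projection matrices and sizes are arbitrary assumptions. The point is that every position is handled by the same few matrix multiplications, with no step waiting on the previous token.

```python
import numpy as np

def softmax(x, axis=-1):
    # subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention.

    X: (seq_len, d_model) token representations.
    W_q, W_k, W_v: (d_model, d_k) projection matrices.
    """
    Q = X @ W_q                           # queries for all positions at once
    K = X @ W_k                           # keys for all positions at once
    V = X @ W_v                           # values for all positions at once
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # (seq_len, seq_len) pairwise scores
    weights = softmax(scores, axis=-1)    # each row is a distribution over positions
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 8
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape)  # (4, 8)
```

Because each token attends to tokens on both its left and its right, an encoder built from this operation sees bidirectional context, which is what the MLM objective exploits.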

In this article, we focus mainly on an encoder-based model, BERT. Encoder-based models are usually trained in two stages. The first stage is pre-training, where the model is trained with the masked language modeling (MLM) objective: some input tokens are masked, and the model must predict them. Because the model can use the words on both sides of a masked token, MLM forces it to learn context in both directions. The second stage is fine-tuning, where the model is trained on a specific downstream task such as sentiment classification.
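The masking step of MLM can be illustrated with a simplified pure-Python routine following the scheme described in the BERT paper: about 15% of tokens are selected, and of those, 80% become `[MASK]`, 10% become a random token, and 10% are left unchanged. This is a sketch, not the Hugging Face data collator. The IDs assume the bert-base-uncased vocabulary, `mask_tokens` is a hypothetical helper, and unlike real pipelines it does not skip special tokens such as `[CLS]` and `[SEP]`.

```python
import random

MASK_ID = 103        # [MASK] id in the bert-base-uncased vocabulary
VOCAB_SIZE = 30522   # bert-base-uncased vocabulary size
IGNORE_INDEX = -100  # label value the loss ignores (Hugging Face convention)

def mask_tokens(input_ids, mlm_probability=0.15, seed=None):
    """Return (masked_inputs, labels) for one sequence of token ids."""
    rng = random.Random(seed)
    inputs = list(input_ids)
    labels = [IGNORE_INDEX] * len(inputs)
    for i, token_id in enumerate(inputs):
        if rng.random() >= mlm_probability:
            continue                     # not selected: no loss at this position
        labels[i] = token_id             # the model must recover the original token
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK_ID          # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels

ids = [101, 7592, 2088, 2003, 2307, 102]
masked, labels = mask_tokens(ids, seed=0)
```

The loss is computed only at positions whose label is not -100, so the model is trained purely on recovering the selected tokens from their surrounding context.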

Knowledge distillation is a popular model compression approach. It transfers knowledge from a large teacher model to a much smaller student model, which is trained to reproduce the teacher's outputs. There are two main types of knowledge distillation:

  1. Task-specific knowledge distillation: distillation happens only at the fine-tuning stage, for a specific task.

  2. Task-agnostic knowledge distillation: distillation happens mainly during pre-training, independently of any task. The distilled model can then be fine-tuned for any downstream task.
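The student learns from the teacher's full probability distribution, not just its top prediction. Dividing the logits by a temperature T > 1 before the softmax flattens that distribution, so the relative probabilities of the "wrong" classes, which encode how similar the classes look to the teacher, become visible. A small pure-Python sketch with hypothetical 3-class teacher logits (SST-2 itself has only 2 classes):

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    # Dividing logits by T > 1 flattens the distribution;
    # the ranking of the classes is unchanged
    scaled = [z / temperature for z in logits]
    m = max(scaled)                       # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

teacher_logits = [4.0, 1.0, 0.5]          # hypothetical teacher logits
for t in (1.0, 3.0):
    probs = softmax_with_temperature(teacher_logits, t)
    print(t, [round(p, 3) for p in probs])
```

At T = 1 almost all the mass sits on the first class; at T = 3 the smaller classes receive noticeably more probability while the ranking stays the same. The distillation code later in this article applies the same softening to both teacher and student logits.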

We will use the Hugging Face Transformers library to implement task-specific knowledge distillation.

student_id = "distilbert-base-uncased"
teacher_id = "textattack/bert-base-uncased-SST-2"

We use DistilBERT as the student and a BERT-base model fine-tuned on SST-2 as the teacher. DistilBERT has only 6 encoder layers, whereas BERT-base has 12, so the student has half as many encoder layers as the teacher.

from transformers import AutoTokenizer

# init tokenizers
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_id)
student_tokenizer = AutoTokenizer.from_pretrained(student_id)

# sample input
sample = "Testing tokenizers."

# compare token IDs only: the BERT tokenizer also returns token_type_ids,
# which the DistilBERT tokenizer does not, so comparing the full outputs would fail
assert teacher_tokenizer(sample)["input_ids"] == student_tokenizer(sample)["input_ids"], "Tokenizers produced different output"

We need to confirm that the teacher and student tokenizers produce identical token IDs. Otherwise, the two models would receive different inputs for the same text, and the teacher's predictions would be a poor training signal for the student.

We will use the Stanford Sentiment Treebank (SST-2), a binary sentiment classification dataset from the GLUE benchmark, for the experiment.

from datasets import load_dataset

dataset = load_dataset('glue','sst2')

With the dataset loaded, we now need to tokenize it.

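from transformers import AutoTokenizer

# teacher and student tokenizers produce the same token IDs (checked above)
tokenizer = AutoTokenizer.from_pretrained(teacher_id)

def process(examples):
    # truncate to the 512-token limit of both models; padding is applied per batch later
    tokenized_inputs = tokenizer(
        examples["sentence"], truncation=True, max_length=512
    )
    return tokenized_inputs

tokenized_datasets = dataset.map(process, batched=True)
# Trainer expects the target column to be named "labels";
# columns the student's forward() does not accept (e.g. sentence, idx, token_type_ids)
# are dropped by Trainer by default (remove_unused_columns=True)
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")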
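The `process` function truncates but does not pad. Padding is applied per batch at training time by `DataCollatorWithPadding`, so each batch is padded only to its own longest sequence rather than to 512 tokens. A simplified pure-Python illustration of that dynamic padding idea (not the library's implementation; `pad_batch` is a hypothetical helper that handles only `input_ids` and `attention_mask`, and assumes `[PAD]` id 0 as in bert-base-uncased):

```python
PAD_ID = 0  # [PAD] id in the bert-base-uncased vocabulary

def pad_batch(batch):
    # Pad every sequence to the longest one in *this* batch, not a global maximum
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, attention_mask = [], []
    for ex in batch:
        n_pad = max_len - len(ex["input_ids"])
        input_ids.append(ex["input_ids"] + [PAD_ID] * n_pad)
        # 1 = real token the model should attend to, 0 = padding to ignore
        attention_mask.append([1] * len(ex["input_ids"]) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# two short example sequences (the non-special token ids are arbitrary)
batch = [{"input_ids": [101, 2307, 3185, 102]},
         {"input_ids": [101, 6659, 102]}]
padded = pad_batch(batch)
```

Since most SST-2 sentences are short, padding per batch avoids spending compute on hundreds of padding tokens per example.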

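Next, we write the classes that perform knowledge distillation: a TrainingArguments subclass that adds the alpha and temperature hyperparameters, and a Trainer subclass whose loss combines the student's own classification loss with a distillation loss against the teacher.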

from transformers import TrainingArguments, Trainer
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        # alpha weights the student's hard-label loss; (1 - alpha) weights the distillation loss
        self.alpha = alpha
        # temperature > 1 softens both probability distributions
        self.temperature = temperature

class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on the same device as the student and disable dropout
        self.teacher.to(self.model.device)
        self.teacher.eval()

    # **kwargs absorbs extra arguments (e.g. num_items_in_batch)
    # that newer Transformers versions pass to compute_loss
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # compute student output (includes cross-entropy loss on the labels)
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss
        # compute teacher output without tracking gradients
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        # teacher and student must predict over the same label set
        assert outputs_student.logits.size() == outputs_teacher.logits.size()

        # soften probabilities and compute the KL-divergence distillation loss;
        # multiplying by T^2 keeps its gradient scale comparable to the hard-label loss
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = loss_function(
            F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
            F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1),
        ) * (self.args.temperature ** 2)
        # return weighted combination of student loss and distillation loss
        loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss
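To see what `compute_loss` computes without loading any models, here is a NumPy re-implementation of the same formula on toy logits: alpha times the student's cross-entropy on the true labels, plus (1 - alpha) times T² times the batch-averaged KL divergence between the temperature-softened teacher and student distributions (matching `KLDivLoss(reduction="batchmean")`). The helper names are hypothetical; this is a sketch for checking the arithmetic, not part of Transformers.

```python
import numpy as np

def log_softmax(x):
    # numerically stable log-softmax over the last axis
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    batch = student_logits.shape[0]
    # hard-label cross-entropy of the student (what outputs_student.loss holds)
    ce = -log_softmax(student_logits)[np.arange(batch), labels].mean()
    # KL(teacher || student) on softened distributions, summed over classes
    # and divided by the batch size, like KLDivLoss(reduction="batchmean")
    log_p_s = log_softmax(student_logits / temperature)
    log_p_t = log_softmax(teacher_logits / temperature)
    kl = (np.exp(log_p_t) * (log_p_t - log_p_s)).sum() / batch
    # T^2 scaling compensates for the 1/T^2 shrinkage of the soft-target gradients
    return alpha * ce + (1.0 - alpha) * kl * temperature ** 2

student = np.array([[2.0, -1.0], [0.5, 0.3]])   # toy logits, 2 examples x 2 classes
teacher = np.array([[3.0, -2.0], [-0.5, 1.0]])
labels = np.array([0, 1])
print(distillation_loss(student, teacher, labels, alpha=0.5, temperature=3.0))
```

If the student's logits equal the teacher's, the KL term vanishes and only alpha times the cross-entropy remains; with alpha = 0 the student is trained on the teacher's soft targets alone.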

Now, we can specify the hyperparameters and start the training.

from transformers import AutoModelForSequenceClassification, DataCollatorWithPadding
import numpy as np

# create label2id, id2label dicts for nice outputs for the model
labels = tokenized_datasets["train"].features["labels"].names
num_labels = len(labels)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

# accuracy on the evaluation set (used by trainer.evaluate() and metric_for_best_model)
def compute_metrics(eval_pred):
    predictions, label_ids = eval_pred
    predictions = np.argmax(predictions, axis=-1)
    return {"accuracy": float((predictions == label_ids).mean())}

# define training args
training_args = DistillationTrainingArguments(
    output_dir="distilbert-sst2-distilled",  # hypothetical folder name; required in many Transformers versions
    num_train_epochs=3,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    learning_rate=5e-5,
    metric_for_best_model="accuracy",
    alpha=0.5,
    temperature=3.0,
)

# define data_collator (pads each batch dynamically)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# define teacher model
teacher_model = AutoModelForSequenceClassification.from_pretrained(
    teacher_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# define student model
student_model = AutoModelForSequenceClassification.from_pretrained(
    student_id,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

trainer = DistillationTrainer(
    student_model,
    training_args,
    teacher_model=teacher_model,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.evaluate()

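We use accuracy as the evaluation metric. To run the training on Amazon SageMaker, a few modifications are needed: the training code goes into a script that the SageMaker HuggingFace estimator runs, and the settings are passed to that script as hyperparameters.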

from sagemaker.huggingface import HuggingFace

# hyperparameters, which are passed into the training job
hyperparameters = {
    'teacher_id': 'textattack/bert-base-uncased-SST-2',
    'student_id': 'distilbert-base-uncased',
    'dataset_id': 'glue',
    'dataset_config': 'sst2',
    # distillation parameters
    'alpha': 0.5,
    'temperature': 3,
}

# create the Estimator
# (the remaining arguments, such as the entry-point script, instance type,
#  IAM role and framework versions, are elided in the original)
huggingface_estimator = HuggingFace(..., hyperparameters=hyperparameters)

# start knowledge distillation training
huggingface_estimator.fit()
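Inside the training job, SageMaker passes each hyperparameter to the entry-point script as a command-line argument (for example `--alpha 0.5`), so the script has to parse them before building `DistillationTrainingArguments`. A minimal sketch of that parsing, assuming a hypothetical `parse_args` helper whose argument names match the keys of the `hyperparameters` dict; `SM_MODEL_DIR` is the environment variable SageMaker sets to the folder whose contents are saved as the model artifact, and the `"model"` fallback is only for running the script locally:

```python
import argparse
import os

def parse_args(argv=None):
    # SageMaker calls the entry point roughly as:
    #   python script.py --alpha 0.5 --temperature 3 ...
    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher_id", type=str, default="textattack/bert-base-uncased-SST-2")
    parser.add_argument("--student_id", type=str, default="distilbert-base-uncased")
    parser.add_argument("--dataset_id", type=str, default="glue")
    parser.add_argument("--dataset_config", type=str, default="sst2")
    # distillation parameters
    parser.add_argument("--alpha", type=float, default=0.5)
    parser.add_argument("--temperature", type=float, default=3.0)
    # where the trained student should be saved
    parser.add_argument("--output_dir", type=str, default=os.environ.get("SM_MODEL_DIR", "model"))
    return parser.parse_args(argv)

# simulate the arguments SageMaker would pass
args = parse_args(["--alpha", "0.7", "--temperature", "4"])
```

The parsed values would then replace the hard-coded ones, e.g. `DistillationTrainingArguments(output_dir=args.output_dir, alpha=args.alpha, temperature=args.temperature, ...)`. A hyperparameter key that does not match an argument name makes `argparse` exit with an "unrecognized arguments" error, so the keys are worth double-checking.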

In this way, we can create a compact student model that is smaller and faster than the teacher, which makes it easier to deploy on mobile and other low-resource devices. For reference, DistilBERT has about 40% fewer parameters than BERT-base (roughly 66M vs. 110M). A distilled student can often retain most of the teacher's accuracy, although a small drop is typical.

In the next article, we will learn about another model compression technique: quantization.

Reference:

https://www.philschmid.de/knowledge-distillation-bert-transformers
