# E2. 使用 continuous prompt 完成 SST2 分类

In [1]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

import torch.nn as nn

import transformers
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

import sys
sys.path.append('..')

import fastNLP
from fastNLP import Trainer
from fastNLP.core.metrics import Accuracy

print(transformers.__version__)

4.18.0


In [2]:
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

task = "sst2"
model_checkpoint = "distilbert-base-uncased"

In [3]:
class ClassModel(nn.Module):
    def __init__(self, model_checkpoint, num_labels, pre_seq_len):
        nn.Module.__init__(self)
        self.num_labels = num_labels
        self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, 
                                                                            num_labels=num_labels)
        self.embeddings = self.back_bone.get_input_embeddings()

        for param in self.back_bone.parameters():
            param.requires_grad = False
        
        self.pre_seq_len = pre_seq_len
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)
    
    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(self, input_ids, attention_mask, labels):
        
        batch_size = input_ids.shape[0]
        raw_embedding = self.embeddings(input_ids)
        
        prompts = self.get_prompt(batch_size=batch_size)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.back_bone(inputs_embeds=inputs_embeds, 
                                 attention_mask=attention_mask, labels=labels)
        return outputs

    def train_step(self, input_ids, attention_mask, labels):
        return {"loss": self(input_ids, attention_mask, labels).loss}

    def evaluate_step(self, input_ids, attention_mask, labels):
        pred = self(input_ids, attention_mask, labels).logits
        pred = torch.max(pred, dim=-1)[1]
        return {"pred": pred, "target": labels}

In [17]:
num_labels = 3 if task.startswith("mnli") else 1 if task == "stsb" else 2

model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=16)

# Generally, simple classification tasks prefer shorter prompts (less than 20)

optimizers = AdamW(params=model.parameters(), lr=5e-3)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classi

In [5]:
from datasets import load_dataset, load_metric

dataset = load_dataset("glue", "mnli" if task == "mnli-mm" else task)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.
Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
def preprocess_function(examples):
    return tokenizer(examples['sentence'], truncation=True)

encoded_dataset = dataset.map(preprocess_function, batched=True)

Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow
Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow


  0%|          | 0/2 [00:00<?, ?ba/s]

In [7]:
class TestDistilBertDataset(Dataset):
    def __init__(self, dataset):
        super(TestDistilBertDataset, self).__init__()
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        item = self.dataset[item]
        return item["input_ids"], item["attention_mask"], [item["label"]] 

In [8]:
def test_bert_collate_fn(batch):
    input_ids, atten_mask, labels = [], [], []
    max_length = [0] * 3
    for each_item in batch:
        input_ids.append(each_item[0])
        max_length[0] = max(max_length[0], len(each_item[0]))
        atten_mask.append(each_item[1])
        max_length[1] = max(max_length[1], len(each_item[1]))
        labels.append(each_item[2])
        max_length[2] = max(max_length[2], len(each_item[2]))

    for i in range(3):
        each = (input_ids, atten_mask, labels)[i]
        for item in each:
            item.extend([0] * (max_length[i] - len(item)))
    return {"input_ids": torch.cat([torch.tensor([item]) for item in input_ids], dim=0),
            "attention_mask": torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),
            "labels": torch.cat([torch.tensor(item) for item in labels], dim=0)}

In [9]:
dataset_train = TestDistilBertDataset(encoded_dataset["train"])
dataloader_train = DataLoader(dataset=dataset_train, 
                              batch_size=32, shuffle=True, collate_fn=test_bert_collate_fn)
dataset_valid = TestDistilBertDataset(encoded_dataset["validation"])
dataloader_valid = DataLoader(dataset=dataset_valid, 
                              batch_size=32, shuffle=False, collate_fn=test_bert_collate_fn)

In [18]:
trainer = Trainer(
    model=model,
    driver='torch',
    device='cuda',
    n_epochs=10,
    optimizers=optimizers,
    train_dataloader=dataloader_train,
    evaluate_dataloaders=dataloader_valid,
    metrics={'acc': Accuracy()}
)

In [19]:
trainer.run(num_eval_batch_per_dl=10)

In [20]:
trainer.evaluator.run()

{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}