{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# E2. 使用 continuous prompt 完成 SST2 分类" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.18.0\n"
]
}
],
"source": [
"import torch\n",
"from torch.optim import AdamW\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"import torch.nn as nn\n",
"\n",
"import transformers\n",
"from transformers import AutoTokenizer\n",
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"import sys\n",
"sys.path.append('..')\n",
"\n",
"import fastNLP\n",
"from fastNLP import Trainer\n",
"from fastNLP.core.metrics import Accuracy\n",
"\n",
"print(transformers.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Experiment configuration.\n",
"GLUE_TASKS = [\"cola\", \"mnli\", \"mnli-mm\", \"mrpc\", \"qnli\", \"qqp\", \"rte\", \"sst2\", \"stsb\", \"wnli\"]\n",
"\n",
"task = \"sst2\"\n",
"model_checkpoint = \"distilbert-base-uncased\"\n",
"\n",
"# Seed every RNG before the classification head and prompt embeddings are initialised\n",
"# and before the shuffling DataLoader is built, so Restart & Run All is reproducible.\n",
"SEED = 42\n",
"transformers.set_seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class ClassModel(nn.Module):\n",
"    \"\"\"Continuous-prompt (prompt tuning) wrapper around a sequence-classification model.\n",
"\n",
"    `pre_seq_len` trainable embedding vectors are prepended to the token embeddings of\n",
"    every input. The pretrained encoder is frozen; only the prompt embeddings and, by\n",
"    default, the classification head receive gradients.\n",
"\n",
"    Args:\n",
"        model_checkpoint: name/path passed to `AutoModelForSequenceClassification.from_pretrained`.\n",
"        num_labels: number of output classes.\n",
"        pre_seq_len: number of continuous prompt vectors to prepend.\n",
"        train_head: keep the task head trainable (default True). The head is freshly\n",
"            initialised, so freezing it makes the model classify through random weights;\n",
"            pass False to freeze the entire backbone, head included.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model_checkpoint, num_labels, pre_seq_len, train_head=True):\n",
"        super().__init__()\n",
"        self.num_labels = num_labels\n",
"        self.back_bone = AutoModelForSequenceClassification.from_pretrained(\n",
"            model_checkpoint, num_labels=num_labels)\n",
"        self.embeddings = self.back_bone.get_input_embeddings()\n",
"\n",
"        # `base_model` is the pretrained encoder without the task head, so freezing only\n",
"        # it leaves the newly initialised head trainable.\n",
"        frozen_part = self.back_bone.base_model if train_head else self.back_bone\n",
"        for param in frozen_part.parameters():\n",
"            param.requires_grad = False\n",
"\n",
"        self.pre_seq_len = pre_seq_len\n",
"        # A non-persistent buffer follows the module across `.to(device)` calls and is\n",
"        # not written into the state dict.\n",
"        self.register_buffer(\"prefix_tokens\", torch.arange(pre_seq_len).long(), persistent=False)\n",
"        self.prefix_encoder = nn.Embedding(pre_seq_len, self.embeddings.embedding_dim)\n",
"\n",
"    def get_prompt(self, batch_size):\n",
"        \"\"\"Return the prompt embeddings, shape (batch_size, pre_seq_len, embedding_dim).\"\"\"\n",
"        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)\n",
"        return self.prefix_encoder(prefix_tokens)\n",
"\n",
"    def forward(self, input_ids, attention_mask, labels=None):\n",
"        \"\"\"Prepend the prompt to the token embeddings and run the backbone.\n",
"\n",
"        Args:\n",
"            input_ids: token ids, shape (batch, seq_len).\n",
"            attention_mask: padding mask, shape (batch, seq_len).\n",
"            labels: optional targets; when given the backbone output also carries `loss`.\n",
"\n",
"        Returns:\n",
"            The backbone's output object (`logits`, plus `loss` when labels are given).\n",
"        \"\"\"\n",
"        batch_size = input_ids.shape[0]\n",
"        raw_embedding = self.embeddings(input_ids)\n",
"\n",
"        prompts = self.get_prompt(batch_size=batch_size)\n",
"        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n",
"        # Prompt positions are always attended to. Match the incoming mask's dtype and\n",
"        # device so the concatenation does not rely on implicit type promotion.\n",
"        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len,\n",
"                                           dtype=attention_mask.dtype,\n",
"                                           device=attention_mask.device)\n",
"        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n",
"\n",
"        return self.back_bone(inputs_embeds=inputs_embeds,\n",
"                              attention_mask=attention_mask, labels=labels)\n",
"\n",
"    def train_step(self, input_ids, attention_mask, labels):\n",
"        \"\"\"fastNLP training hook: return the loss to optimise.\"\"\"\n",
"        return {\"loss\": self(input_ids, attention_mask, labels).loss}\n",
"\n",
"    def evaluate_step(self, input_ids, attention_mask, labels):\n",
"        \"\"\"fastNLP evaluation hook: return argmax predictions and targets for the metric.\"\"\"\n",
"        # Labels are not passed to the model, so no loss is computed during evaluation.\n",
"        logits = self(input_ids, attention_mask).logits\n",
"        pred = logits.argmax(dim=-1)\n",
"        return {\"pred\": pred, \"target\": labels}"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n",
"- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n",
"\n",
"# Generally, simple classification tasks prefer shorter prompts (less than 20)\n",
"PRE_SEQ_LEN = 16\n",
"LEARNING_RATE = 5e-3\n",
"\n",
"model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=PRE_SEQ_LEN)\n",
"\n",
"# Hand the optimizer only the parameters that are actually trained; the frozen\n",
"# backbone weights (requires_grad=False) never receive gradients.\n",
"optimizers = AdamW(params=[p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n",
"Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1b73650d43f245ac8a5501dc91c6fe8c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from datasets import load_dataset\n",
"\n",
"# \"mnli-mm\" is not its own config name here; it is loaded from the \"mnli\" config.\n",
"dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n",
"Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0be84915c90f460896b8e67299e09df4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2 [00:00, ?ba/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def preprocess_function(examples):\n",
"    \"\"\"Tokenize a batch of sentences, leaving room for the continuous prompt.\n",
"\n",
"    ClassModel prepends `model.pre_seq_len` prompt vectors to every input, so the\n",
"    token sequence is truncated to `tokenizer.model_max_length - model.pre_seq_len`.\n",
"    Assumes `tokenizer.model_max_length` matches the backbone's position limit.\n",
"    \"\"\"\n",
"    max_length = tokenizer.model_max_length - model.pre_seq_len\n",
"    return tokenizer(examples['sentence'], truncation=True, max_length=max_length)\n",
"\n",
"encoded_dataset = dataset.map(preprocess_function, batched=True)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class TestDistilBertDataset(Dataset):\n",
"    \"\"\"Map-style torch view over a tokenized split.\n",
"\n",
"    Each item is `(input_ids, attention_mask, [label])`; the label is wrapped in a\n",
"    one-element list so the collate function can pad all three fields the same way.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, dataset):\n",
"        super().__init__()\n",
"        self.dataset = dataset\n",
"\n",
"    def __len__(self):\n",
"        return len(self.dataset)\n",
"\n",
"    def __getitem__(self, item):\n",
"        record = self.dataset[item]\n",
"        token_ids = record[\"input_ids\"]\n",
"        mask = record[\"attention_mask\"]\n",
"        return token_ids, mask, [record[\"label\"]]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def test_bert_collate_fn(batch, pad_id=0):\n",
"    \"\"\"Right-pad a batch of (input_ids, attention_mask, labels) items into tensors.\n",
"\n",
"    Each field is padded to the longest entry of that field in the batch. Padded\n",
"    copies are built, so the lists returned by the dataset are never extended in place.\n",
"\n",
"    Args:\n",
"        batch: sequence of 3-tuples of lists, as returned by `TestDistilBertDataset`.\n",
"        pad_id: fill value for `input_ids` (default 0); should match\n",
"            `tokenizer.pad_token_id`. Masks and labels are always padded with 0.\n",
"\n",
"    Returns:\n",
"        dict with `input_ids` and `attention_mask` of shape (batch, longest sequence)\n",
"        and `labels` flattened to 1-D (shape (batch,) when each item has one label).\n",
"    \"\"\"\n",
"    def _pad(rows, value):\n",
"        # Copy each row and right-pad it to the length of the longest row.\n",
"        width = max((len(row) for row in rows), default=0)\n",
"        return [list(row) + [value] * (width - len(row)) for row in rows]\n",
"\n",
"    input_ids = _pad([each_item[0] for each_item in batch], pad_id)\n",
"    atten_mask = _pad([each_item[1] for each_item in batch], 0)\n",
"    labels = _pad([each_item[2] for each_item in batch], 0)\n",
"    return {\"input_ids\": torch.tensor(input_ids),\n",
"            \"attention_mask\": torch.tensor(atten_mask),\n",
"            \"labels\": torch.tensor(labels).reshape(-1)}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"dataset_train = TestDistilBertDataset(encoded_dataset[\"train\"])\n",
"dataset_valid = TestDistilBertDataset(encoded_dataset[\"validation\"])\n",
"\n",
"# Shuffle only the training split; the validation order stays fixed.\n",
"dataloader_train = DataLoader(dataset_train, batch_size=32, shuffle=True,\n",
"                              collate_fn=test_bert_collate_fn)\n",
"dataloader_valid = DataLoader(dataset_valid, batch_size=32, shuffle=False,\n",
"                              collate_fn=test_bert_collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Fall back to CPU so the notebook still runs on machines without a GPU.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    driver='torch',\n",
"    device=device,\n",
"    n_epochs=10,\n",
"    optimizers=optimizers,\n",
"    train_dataloader=dataloader_train,\n",
"    evaluate_dataloaders=dataloader_valid,\n",
"    metrics={'acc': Accuracy()}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
"\n"
],
"text/plain": [
"\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trainer.run(num_eval_batch_per_dl=10)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n"
],
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.evaluator.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}