| @@ -39,7 +39,6 @@ | |||||
| "from torch.utils.data import DataLoader, Dataset\n", | "from torch.utils.data import DataLoader, Dataset\n", | ||||
| "\n", | "\n", | ||||
| "import torch.nn as nn\n", | "import torch.nn as nn\n", | ||||
| "from torch.nn.utils.rnn import pad_sequence\n", | |||||
| "\n", | "\n", | ||||
| "import transformers\n", | "import transformers\n", | ||||
| "from transformers import AutoTokenizer\n", | "from transformers import AutoTokenizer\n", | ||||
| @@ -50,7 +49,6 @@ | |||||
| "\n", | "\n", | ||||
| "import fastNLP\n", | "import fastNLP\n", | ||||
| "from fastNLP import Trainer\n", | "from fastNLP import Trainer\n", | ||||
| "from fastNLP.core.utils.utils import dataclass_to_dict\n", | |||||
| "from fastNLP.core.metrics import Accuracy\n", | "from fastNLP.core.metrics import Accuracy\n", | ||||
| "\n", | "\n", | ||||
| "print(transformers.__version__)" | "print(transformers.__version__)" | ||||
| @@ -73,134 +71,80 @@ | |||||
| "execution_count": 3, | "execution_count": 3, | ||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | |||||
| "class PromptEncoder(nn.Module):\n", | |||||
| " def __init__(self, template, hidden_size):\n", | |||||
| " nn.Module.__init__(self)\n", | |||||
| " self.template = template\n", | |||||
| " self.hidden_size = hidden_size\n", | |||||
| " self.cloze_mask = [[1] * self.template[0] + [1] * self.template[1]]\n", | |||||
| " self.cloze_mask = torch.LongTensor(self.cloze_mask).bool()\n", | |||||
| "\n", | |||||
| " self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0]))))\n", | |||||
| " # embed\n", | |||||
| " self.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), hidden_size)\n", | |||||
| " # LSTM\n", | |||||
| " self.lstm_head = torch.nn.LSTM(input_size=hidden_size,\n", | |||||
| " hidden_size=hidden_size // 2,\n", | |||||
| " num_layers=2, dropout=0.0,\n", | |||||
| " bidirectional=True, batch_first=True)\n", | |||||
| " # MLP\n", | |||||
| " self.mlp_head = nn.Sequential(nn.Linear(hidden_size, hidden_size),\n", | |||||
| " nn.ReLU(),\n", | |||||
| " nn.Linear(hidden_size, hidden_size))\n", | |||||
| " print(\"init prompt encoder...\")\n", | |||||
| "\n", | |||||
| " def forward(self, device):\n", | |||||
| " input_embeds = self.embedding(self.seq_indices.to(device)).unsqueeze(0)\n", | |||||
| " output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()\n", | |||||
| " return output_embeds" | |||||
| ] | |||||
| }, | |||||
| { | |||||
| "cell_type": "code", | |||||
| "execution_count": 4, | |||||
| "metadata": {}, | |||||
| "outputs": [], | |||||
| "source": [ | "source": [ | ||||
| "class ClassModel(nn.Module):\n", | "class ClassModel(nn.Module):\n", | ||||
| " def __init__(self, num_labels, model_checkpoint, pseudo_token='[PROMPT]', template=(3, 3)):\n", | |||||
| " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", | |||||
| " nn.Module.__init__(self)\n", | " nn.Module.__init__(self)\n", | ||||
| " self.template = template\n", | |||||
| " self.num_labels = num_labels\n", | " self.num_labels = num_labels\n", | ||||
| " self.spell_length = sum(template)\n", | |||||
| " self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", | |||||
| " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", | " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", | ||||
| " num_labels=num_labels)\n", | " num_labels=num_labels)\n", | ||||
| " self.embeddings = self.back_bone.get_input_embeddings()\n", | |||||
| "\n", | |||||
| " for param in self.back_bone.parameters():\n", | " for param in self.back_bone.parameters():\n", | ||||
| " param.requires_grad = False\n", | " param.requires_grad = False\n", | ||||
| " self.embeddings = self.back_bone.get_input_embeddings()\n", | |||||
| " \n", | |||||
| " self.hidden_size = self.embeddings.embedding_dim\n", | |||||
| " self.tokenizer.add_special_tokens({'additional_special_tokens': [pseudo_token]})\n", | |||||
| " self.pseudo_token_id = self.tokenizer.get_vocab()[pseudo_token]\n", | |||||
| " self.pad_token_id = self.tokenizer.pad_token_id\n", | |||||
| " \n", | " \n", | ||||
| " self.prompt_encoder = PromptEncoder(self.template, self.hidden_size)\n", | |||||
| "\n", | |||||
| " self.loss_fn = nn.CrossEntropyLoss()\n", | |||||
| "\n", | |||||
| " def get_query(self, query):\n", | |||||
| " device = query.device\n", | |||||
| " return torch.cat([torch.tensor([self.tokenizer.cls_token_id]).to(device), # [CLS]\n", | |||||
| " torch.tensor([self.pseudo_token_id] * self.template[0]).to(device), # [PROMPT]\n", | |||||
| " torch.tensor([self.tokenizer.mask_token_id]).to(device), # [MASK] \n", | |||||
| " torch.tensor([self.pseudo_token_id] * self.template[1]).to(device), # [PROMPT]\n", | |||||
| " query, \n", | |||||
| " torch.tensor([self.tokenizer.sep_token_id]).to(device)], dim=0) # [SEP]\n", | |||||
| " self.pre_seq_len = pre_seq_len\n", | |||||
| " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", | |||||
| " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", | |||||
| " \n", | |||||
| " def get_prompt(self, batch_size):\n", | |||||
| " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", | |||||
| " prompts = self.prefix_encoder(prefix_tokens)\n", | |||||
| " return prompts\n", | |||||
| "\n", | "\n", | ||||
| " def forward(self, input_ids):\n", | |||||
| " input_ids = torch.stack([self.get_query(input_ids[i]) for i in range(len(input_ids))])\n", | |||||
| " attention_mask = input_ids != self.pad_token_id\n", | |||||
| " def forward(self, input_ids, attention_mask, labels):\n", | |||||
| " \n", | " \n", | ||||
| " bz = input_ids.shape[0]\n", | |||||
| " inputs_embeds = input_ids.clone()\n", | |||||
| " inputs_embeds[(input_ids == self.pseudo_token_id)] = self.tokenizer.unk_token_id\n", | |||||
| " inputs_embeds = self.embeddings(inputs_embeds)\n", | |||||
| "\n", | |||||
| " blocked_indices = (input_ids == self.pseudo_token_id).nonzero().reshape((bz, self.spell_length, 2))[:, :, 1] # bz\n", | |||||
| " replace_embeds = self.prompt_encoder(input_ids.device)\n", | |||||
| " for bidx in range(bz):\n", | |||||
| " for i in range(self.spell_length):\n", | |||||
| " inputs_embeds[bidx, blocked_indices[bidx, i], :] = replace_embeds[i, :]\n", | |||||
| " batch_size = input_ids.shape[0]\n", | |||||
| " raw_embedding = self.embeddings(input_ids)\n", | |||||
| " \n", | " \n", | ||||
| " return self.back_bone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)\n", | |||||
| " prompts = self.get_prompt(batch_size=batch_size)\n", | |||||
| " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", | |||||
| " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", | |||||
| " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", | |||||
| "\n", | |||||
| " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n", | |||||
| " attention_mask=attention_mask, labels=labels)\n", | |||||
| " return outputs\n", | |||||
| "\n", | "\n", | ||||
| " def train_step(self, input_ids, attention_mask, labels):\n", | " def train_step(self, input_ids, attention_mask, labels):\n", | ||||
| " pred = self(input_ids).logits\n", | |||||
| " return {\"loss\": self.loss_fn(pred, labels)}\n", | |||||
| " return {\"loss\": self(input_ids, attention_mask, labels).loss}\n", | |||||
| "\n", | "\n", | ||||
| " def evaluate_step(self, input_ids, attention_mask, labels):\n", | " def evaluate_step(self, input_ids, attention_mask, labels):\n", | ||||
| " pred = self(input_ids).logits\n", | |||||
| " pred = self(input_ids, attention_mask, labels).logits\n", | |||||
| " pred = torch.max(pred, dim=-1)[1]\n", | " pred = torch.max(pred, dim=-1)[1]\n", | ||||
| " return {\"pred\": pred, \"target\": labels}" | " return {\"pred\": pred, \"target\": labels}" | ||||
| ] | ] | ||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 5, | |||||
| "execution_count": 17, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [ | "outputs": [ | ||||
| { | { | ||||
| "name": "stderr", | "name": "stderr", | ||||
| "output_type": "stream", | "output_type": "stream", | ||||
| "text": [ | "text": [ | ||||
| "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_transform.weight']\n", | |||||
| "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight']\n", | |||||
| "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | ||||
| "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", | ||||
| "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'classifier.bias']\n", | |||||
| "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n", | |||||
| "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | ||||
| ] | ] | ||||
| }, | |||||
| { | |||||
| "name": "stdout", | |||||
| "output_type": "stream", | |||||
| "text": [ | |||||
| "init prompt encoder...\n" | |||||
| ] | |||||
| } | } | ||||
| ], | ], | ||||
| "source": [ | "source": [ | ||||
| "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", | "num_labels = 3 if task.startswith(\"mnli\") else 1 if task == \"stsb\" else 2\n", | ||||
| "\n", | "\n", | ||||
| "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", | |||||
| "model = ClassModel(num_labels=num_labels, model_checkpoint=model_checkpoint, pre_seq_len=16)\n", | |||||
| "\n", | |||||
| "# Generally, simple classification tasks prefer shorter prompts (less than 20)\n", | |||||
| "\n", | "\n", | ||||
| "optimizers = AdamW(params=model.parameters(), lr=5e-4)" | |||||
| "optimizers = AdamW(params=model.parameters(), lr=5e-3)" | |||||
| ] | ] | ||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 6, | |||||
| "execution_count": 5, | |||||
| "metadata": { | "metadata": { | ||||
| "scrolled": false | "scrolled": false | ||||
| }, | }, | ||||
| @@ -209,13 +153,14 @@ | |||||
| "name": "stderr", | "name": "stderr", | ||||
| "output_type": "stream", | "output_type": "stream", | ||||
| "text": [ | "text": [ | ||||
| "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", | |||||
| "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" | ||||
| ] | ] | ||||
| }, | }, | ||||
| { | { | ||||
| "data": { | "data": { | ||||
| "application/vnd.jupyter.widget-view+json": { | "application/vnd.jupyter.widget-view+json": { | ||||
| "model_id": "f82d2ccee863492582f94552654482f9", | |||||
| "model_id": "1b73650d43f245ac8a5501dc91c6fe8c", | |||||
| "version_major": 2, | "version_major": 2, | ||||
| "version_minor": 0 | "version_minor": 0 | ||||
| }, | }, | ||||
| @@ -230,46 +175,28 @@ | |||||
| "source": [ | "source": [ | ||||
| "from datasets import load_dataset, load_metric\n", | "from datasets import load_dataset, load_metric\n", | ||||
| "\n", | "\n", | ||||
| "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)" | |||||
| "dataset = load_dataset(\"glue\", \"mnli\" if task == \"mnli-mm\" else task)\n", | |||||
| "\n", | |||||
| "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" | |||||
| ] | ] | ||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 7, | |||||
| "execution_count": 6, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [ | "outputs": [ | ||||
| { | { | ||||
| "data": { | |||||
| "application/vnd.jupyter.widget-view+json": { | |||||
| "model_id": "cf324902e7b94ea9be709b979b425c96", | |||||
| "version_major": 2, | |||||
| "version_minor": 0 | |||||
| }, | |||||
| "text/plain": [ | |||||
| " 0%| | 0/68 [00:00<?, ?ba/s]" | |||||
| ] | |||||
| }, | |||||
| "metadata": {}, | |||||
| "output_type": "display_data" | |||||
| }, | |||||
| { | |||||
| "data": { | |||||
| "application/vnd.jupyter.widget-view+json": { | |||||
| "model_id": "21eb6203ec6f4592b8cb8530a59eda49", | |||||
| "version_major": 2, | |||||
| "version_minor": 0 | |||||
| }, | |||||
| "text/plain": [ | |||||
| " 0%| | 0/1 [00:00<?, ?ba/s]" | |||||
| ] | |||||
| }, | |||||
| "metadata": {}, | |||||
| "output_type": "display_data" | |||||
| "name": "stderr", | |||||
| "output_type": "stream", | |||||
| "text": [ | |||||
| "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n", | |||||
| "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n" | |||||
| ] | |||||
| }, | }, | ||||
| { | { | ||||
| "data": { | "data": { | ||||
| "application/vnd.jupyter.widget-view+json": { | "application/vnd.jupyter.widget-view+json": { | ||||
| "model_id": "05b83c4b1a9f44aea805788e1e52db78", | |||||
| "model_id": "0be84915c90f460896b8e67299e09df4", | |||||
| "version_major": 2, | "version_major": 2, | ||||
| "version_minor": 0 | "version_minor": 0 | ||||
| }, | }, | ||||
| @@ -283,14 +210,14 @@ | |||||
| ], | ], | ||||
| "source": [ | "source": [ | ||||
| "def preprocess_function(examples):\n", | "def preprocess_function(examples):\n", | ||||
| " return model.tokenizer(examples['sentence'], truncation=True)\n", | |||||
| " return tokenizer(examples['sentence'], truncation=True)\n", | |||||
| "\n", | "\n", | ||||
| "encoded_dataset = dataset.map(preprocess_function, batched=True)" | "encoded_dataset = dataset.map(preprocess_function, batched=True)" | ||||
| ] | ] | ||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 8, | |||||
| "execution_count": 7, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | "source": [ | ||||
| @@ -309,7 +236,7 @@ | |||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 9, | |||||
| "execution_count": 8, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | "source": [ | ||||
| @@ -335,7 +262,7 @@ | |||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 10, | |||||
| "execution_count": 9, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | "source": [ | ||||
| @@ -349,7 +276,7 @@ | |||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 11, | |||||
| "execution_count": 18, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [], | "outputs": [], | ||||
| "source": [ | "source": [ | ||||
| @@ -367,7 +294,7 @@ | |||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 12, | |||||
| "execution_count": 19, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [ | "outputs": [ | ||||
| { | { | ||||
| @@ -410,7 +337,7 @@ | |||||
| }, | }, | ||||
| { | { | ||||
| "cell_type": "code", | "cell_type": "code", | ||||
| "execution_count": 13, | |||||
| "execution_count": 20, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "outputs": [ | "outputs": [ | ||||
| { | { | ||||
| @@ -436,10 +363,10 @@ | |||||
| { | { | ||||
| "data": { | "data": { | ||||
| "text/plain": [ | "text/plain": [ | ||||
| "{'acc#acc': 0.565367, 'total#acc': 872.0, 'correct#acc': 493.0}" | |||||
| "{'acc#acc': 0.644495, 'total#acc': 872.0, 'correct#acc': 562.0}" | |||||
| ] | ] | ||||
| }, | }, | ||||
| "execution_count": 13, | |||||
| "execution_count": 20, | |||||
| "metadata": {}, | "metadata": {}, | ||||
| "output_type": "execute_result" | "output_type": "execute_result" | ||||
| } | } | ||||