| @@ -2,13 +2,14 @@ | |||
| bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||
| """ | |||
| import os | |||
| import torch | |||
| from torch import nn | |||
| from .base_model import BaseModel | |||
| from ..core.const import Const | |||
| from ..modules.encoder import BertModel | |||
| from ..modules.encoder.bert import BertConfig | |||
| from ..modules.encoder.bert import BertConfig, CONFIG_FILE | |||
| class BertForSequenceClassification(BaseModel): | |||
| @@ -54,6 +55,7 @@ class BertForSequenceClassification(BaseModel): | |||
| self.num_labels = num_labels | |||
| if bert_dir is not None: | |||
| self.bert = BertModel.from_pretrained(bert_dir) | |||
| config = BertConfig(os.path.join(bert_dir, CONFIG_FILE)) | |||
| else: | |||
| if config is None: | |||
| config = BertConfig(30522) | |||
| @@ -67,20 +69,20 @@ class BertForSequenceClassification(BaseModel): | |||
| model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||
| return model | |||
| def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
| _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
| def forward(self, words, seq_len=None, target=None): | |||
| _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False) | |||
| pooled_output = self.dropout(pooled_output) | |||
| logits = self.classifier(pooled_output) | |||
| if labels is not None: | |||
| if target is not None: | |||
| loss_fct = nn.CrossEntropyLoss() | |||
| loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |||
| loss = loss_fct(logits, target) | |||
| return {Const.OUTPUT: logits, Const.LOSS: loss} | |||
| else: | |||
| return {Const.OUTPUT: logits} | |||
| def predict(self, input_ids, token_type_ids=None, attention_mask=None): | |||
| logits = self.forward(input_ids, token_type_ids, attention_mask) | |||
| def predict(self, words, seq_len=None): | |||
| logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT] | |||
| return {Const.OUTPUT: torch.argmax(logits, dim=-1)} | |||
| @@ -140,7 +142,8 @@ class BertForMultipleChoice(BaseModel): | |||
| model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir) | |||
| return model | |||
| def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
| def forward(self, words, seq_len1=None, seq_len2=None, target=None): | |||
| input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2 | |||
| flat_input_ids = input_ids.view(-1, input_ids.size(-1)) | |||
| flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) | |||
| flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) | |||
| @@ -149,15 +152,15 @@ class BertForMultipleChoice(BaseModel): | |||
| logits = self.classifier(pooled_output) | |||
| reshaped_logits = logits.view(-1, self.num_choices) | |||
| if labels is not None: | |||
| if target is not None: | |||
| loss_fct = nn.CrossEntropyLoss() | |||
| loss = loss_fct(reshaped_logits, labels) | |||
| loss = loss_fct(reshaped_logits, target) | |||
| return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss} | |||
| else: | |||
| return {Const.OUTPUT: reshaped_logits} | |||
| def predict(self, input_ids, token_type_ids=None, attention_mask=None): | |||
| logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] | |||
| def predict(self, words, seq_len1=None, seq_len2=None,): | |||
| logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT] | |||
| return {Const.OUTPUT: torch.argmax(logits, dim=-1)} | |||
| @@ -219,27 +222,27 @@ class BertForTokenClassification(BaseModel): | |||
| model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir) | |||
| return model | |||
| def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): | |||
| sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
| def forward(self, words, seq_len1=None, seq_len2=None, target=None): | |||
| sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) | |||
| sequence_output = self.dropout(sequence_output) | |||
| logits = self.classifier(sequence_output) | |||
| if labels is not None: | |||
| if target is not None: | |||
| loss_fct = nn.CrossEntropyLoss() | |||
| # Only keep active parts of the loss | |||
| if attention_mask is not None: | |||
| active_loss = attention_mask.view(-1) == 1 | |||
| if seq_len2 is not None: | |||
| active_loss = seq_len2.view(-1) == 1 | |||
| active_logits = logits.view(-1, self.num_labels)[active_loss] | |||
| active_labels = labels.view(-1)[active_loss] | |||
| active_labels = target.view(-1)[active_loss] | |||
| loss = loss_fct(active_logits, active_labels) | |||
| else: | |||
| loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | |||
| loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1)) | |||
| return {Const.OUTPUT: logits, Const.LOSS: loss} | |||
| else: | |||
| return {Const.OUTPUT: logits} | |||
| def predict(self, input_ids, token_type_ids=None, attention_mask=None): | |||
| logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT] | |||
| def predict(self, words, seq_len1=None, seq_len2=None): | |||
| logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT] | |||
| return {Const.OUTPUT: torch.argmax(logits, dim=-1)} | |||
| @@ -304,34 +307,34 @@ class BertForQuestionAnswering(BaseModel): | |||
| model = cls(config=config, bert_dir=pretrained_model_dir) | |||
| return model | |||
| def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): | |||
| sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) | |||
| def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None): | |||
| sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False) | |||
| logits = self.qa_outputs(sequence_output) | |||
| start_logits, end_logits = logits.split(1, dim=-1) | |||
| start_logits = start_logits.squeeze(-1) | |||
| end_logits = end_logits.squeeze(-1) | |||
| if start_positions is not None and end_positions is not None: | |||
| if target1 is not None and target2 is not None: | |||
| # If we are on multi-GPU, split add a dimension | |||
| if len(start_positions.size()) > 1: | |||
| start_positions = start_positions.squeeze(-1) | |||
| if len(end_positions.size()) > 1: | |||
| end_positions = end_positions.squeeze(-1) | |||
| if len(target1.size()) > 1: | |||
| target1 = target1.squeeze(-1) | |||
| if len(target2.size()) > 1: | |||
| target2 = target2.squeeze(-1) | |||
| # sometimes the start/end positions are outside our model inputs, we ignore these terms | |||
| ignored_index = start_logits.size(1) | |||
| start_positions.clamp_(0, ignored_index) | |||
| end_positions.clamp_(0, ignored_index) | |||
| target1.clamp_(0, ignored_index) | |||
| target2.clamp_(0, ignored_index) | |||
| loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) | |||
| start_loss = loss_fct(start_logits, start_positions) | |||
| end_loss = loss_fct(end_logits, end_positions) | |||
| start_loss = loss_fct(start_logits, target1) | |||
| end_loss = loss_fct(end_logits, target2) | |||
| total_loss = (start_loss + end_loss) / 2 | |||
| return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss} | |||
| else: | |||
| return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits} | |||
| def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs): | |||
| logits = self.forward(input_ids, token_type_ids, attention_mask) | |||
| def predict(self, words, seq_len1=None, seq_len2=None): | |||
| logits = self.forward(words, seq_len1, seq_len2) | |||
| start_logits = logits[Const.OUTPUTS(0)] | |||
| end_logits = logits[Const.OUTPUTS(1)] | |||
| return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1), | |||