之前的finetune代码当dataset最后长度不足指定batch size时会出错,现已修正
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10524066
master
| @@ -18,14 +18,12 @@ logger = logging.get_logger(__name__) | |||
| @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) | |||
| class BertForTextRanking(BertForSequenceClassification): | |||
| def __init__(self, config, **kwargs): | |||
| def __init__(self, config, *args, **kwargs): | |||
| super().__init__(config) | |||
| self.train_batch_size = kwargs.get('train_batch_size', 4) | |||
| neg_sample = kwargs.get('neg_sample', 8) | |||
| self.neg_sample = neg_sample | |||
| setattr(self, self.base_model_prefix, | |||
| BertModel(self.config, add_pooling_layer=True)) | |||
| self.register_buffer( | |||
| 'target_label', | |||
| torch.zeros(self.train_batch_size, dtype=torch.long)) | |||
| def forward(self, | |||
| input_ids=None, | |||
| @@ -55,9 +53,12 @@ class BertForTextRanking(BertForSequenceClassification): | |||
| pooled_output = self.dropout(pooled_output) | |||
| logits = self.classifier(pooled_output) | |||
| if self.base_model.training: | |||
| scores = logits.view(self.train_batch_size, -1) | |||
| scores = logits.view(-1, self.neg_sample + 1) | |||
| batch_size = scores.size(0) | |||
| loss_fct = torch.nn.CrossEntropyLoss() | |||
| loss = loss_fct(scores, self.target_label) | |||
| target_label = torch.zeros( | |||
| batch_size, dtype=torch.long, device=scores.device) | |||
| loss = loss_fct(scores, target_label) | |||
| return AttentionTextClassificationModelOutput( | |||
| loss=loss, | |||
| logits=logits, | |||
| @@ -78,9 +79,11 @@ class BertForTextRanking(BertForSequenceClassification): | |||
| Returns: | |||
| The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | |||
| """ | |||
| num_labels = kwargs.get('num_labels', 1) | |||
| neg_sample = kwargs.get('neg_sample', 4) | |||
| model_args = {} if num_labels is None else {'num_labels': num_labels} | |||
| if neg_sample is not None: | |||
| model_args['neg_sample'] = neg_sample | |||
| model_dir = kwargs.get('model_dir') | |||
| model = super(Model, cls).from_pretrained( | |||
| @@ -39,8 +39,7 @@ class TextRankingDataset(TorchTaskDataset): | |||
| ['title', 'text']) | |||
| self.qid_field = self.dataset_config.get('qid_field', 'query_id') | |||
| if mode == ModeKeys.TRAIN: | |||
| train_config = kwargs.get('train', {}) | |||
| self.neg_samples = train_config.get('neg_samples', 4) | |||
| self.neg_samples = self.dataset_config.get('neg_sample', 4) | |||
| super().__init__(datasets, mode, preprocessor, **kwargs) | |||
| @@ -63,6 +63,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| def test_finetune_msmarco(self): | |||
| def cfg_modify_fn(cfg): | |||
| neg_sample = 4 | |||
| cfg.task = 'text-ranking' | |||
| cfg['preprocessor'] = {'type': 'text-ranking'} | |||
| cfg.train.optimizer.lr = 2e-5 | |||
| @@ -73,7 +74,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| 'pos_sequence': 'positive_passages', | |||
| 'neg_sequence': 'negative_passages', | |||
| 'text_fileds': ['title', 'text'], | |||
| 'qid_field': 'query_id' | |||
| 'qid_field': 'query_id', | |||
| 'neg_sample': neg_sample | |||
| }, | |||
| 'val': { | |||
| 'type': 'bert', | |||
| @@ -84,7 +86,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| 'qid_field': 'query_id' | |||
| }, | |||
| } | |||
| cfg['train']['neg_samples'] = 4 | |||
| cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 | |||
| cfg.train.max_epochs = 1 | |||
| cfg.train.train_batch_size = 4 | |||
| @@ -96,6 +97,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| 'by_epoch': False | |||
| } | |||
| } | |||
| cfg.model['neg_sample'] = 4 | |||
| cfg.train.hooks = [{ | |||
| 'type': 'CheckpointHook', | |||
| 'interval': 1 | |||
| @@ -151,7 +153,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| 'qid_field': 'query_id' | |||
| }, | |||
| } | |||
| cfg['train']['neg_samples'] = 4 | |||
| cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 | |||
| cfg.train.max_epochs = 1 | |||
| cfg.train.train_batch_size = 4 | |||
| @@ -180,9 +181,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| # load dataset | |||
| ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull') | |||
| train_ds = ds['train'].to_hf_dataset() | |||
| train_ds = ds['train'].to_hf_dataset().shard(1000, index=0) | |||
| dev_ds = ds['dev'].to_hf_dataset() | |||
| model_id = 'damo/nlp_rom_passage-ranking_chinese-base' | |||
| self.finetune( | |||
| model_id=model_id, | |||