之前的 finetune 代码在 dataset 最后一个 batch 的长度不足指定的 batch size 时会出错，现已修正。
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10524066
master
| @@ -18,14 +18,12 @@ logger = logging.get_logger(__name__) | |||||
| @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) | @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert) | ||||
| class BertForTextRanking(BertForSequenceClassification): | class BertForTextRanking(BertForSequenceClassification): | ||||
| def __init__(self, config, **kwargs): | |||||
| def __init__(self, config, *args, **kwargs): | |||||
| super().__init__(config) | super().__init__(config) | ||||
| self.train_batch_size = kwargs.get('train_batch_size', 4) | |||||
| neg_sample = kwargs.get('neg_sample', 8) | |||||
| self.neg_sample = neg_sample | |||||
| setattr(self, self.base_model_prefix, | setattr(self, self.base_model_prefix, | ||||
| BertModel(self.config, add_pooling_layer=True)) | BertModel(self.config, add_pooling_layer=True)) | ||||
| self.register_buffer( | |||||
| 'target_label', | |||||
| torch.zeros(self.train_batch_size, dtype=torch.long)) | |||||
| def forward(self, | def forward(self, | ||||
| input_ids=None, | input_ids=None, | ||||
| @@ -55,9 +53,12 @@ class BertForTextRanking(BertForSequenceClassification): | |||||
| pooled_output = self.dropout(pooled_output) | pooled_output = self.dropout(pooled_output) | ||||
| logits = self.classifier(pooled_output) | logits = self.classifier(pooled_output) | ||||
| if self.base_model.training: | if self.base_model.training: | ||||
| scores = logits.view(self.train_batch_size, -1) | |||||
| scores = logits.view(-1, self.neg_sample + 1) | |||||
| batch_size = scores.size(0) | |||||
| loss_fct = torch.nn.CrossEntropyLoss() | loss_fct = torch.nn.CrossEntropyLoss() | ||||
| loss = loss_fct(scores, self.target_label) | |||||
| target_label = torch.zeros( | |||||
| batch_size, dtype=torch.long, device=scores.device) | |||||
| loss = loss_fct(scores, target_label) | |||||
| return AttentionTextClassificationModelOutput( | return AttentionTextClassificationModelOutput( | ||||
| loss=loss, | loss=loss, | ||||
| logits=logits, | logits=logits, | ||||
| @@ -78,9 +79,11 @@ class BertForTextRanking(BertForSequenceClassification): | |||||
| Returns: | Returns: | ||||
| The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained | ||||
| """ | """ | ||||
| num_labels = kwargs.get('num_labels', 1) | num_labels = kwargs.get('num_labels', 1) | ||||
| neg_sample = kwargs.get('neg_sample', 4) | |||||
| model_args = {} if num_labels is None else {'num_labels': num_labels} | model_args = {} if num_labels is None else {'num_labels': num_labels} | ||||
| if neg_sample is not None: | |||||
| model_args['neg_sample'] = neg_sample | |||||
| model_dir = kwargs.get('model_dir') | model_dir = kwargs.get('model_dir') | ||||
| model = super(Model, cls).from_pretrained( | model = super(Model, cls).from_pretrained( | ||||
| @@ -39,8 +39,7 @@ class TextRankingDataset(TorchTaskDataset): | |||||
| ['title', 'text']) | ['title', 'text']) | ||||
| self.qid_field = self.dataset_config.get('qid_field', 'query_id') | self.qid_field = self.dataset_config.get('qid_field', 'query_id') | ||||
| if mode == ModeKeys.TRAIN: | if mode == ModeKeys.TRAIN: | ||||
| train_config = kwargs.get('train', {}) | |||||
| self.neg_samples = train_config.get('neg_samples', 4) | |||||
| self.neg_samples = self.dataset_config.get('neg_sample', 4) | |||||
| super().__init__(datasets, mode, preprocessor, **kwargs) | super().__init__(datasets, mode, preprocessor, **kwargs) | ||||
| @@ -63,6 +63,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| def test_finetune_msmarco(self): | def test_finetune_msmarco(self): | ||||
| def cfg_modify_fn(cfg): | def cfg_modify_fn(cfg): | ||||
| neg_sample = 4 | |||||
| cfg.task = 'text-ranking' | cfg.task = 'text-ranking' | ||||
| cfg['preprocessor'] = {'type': 'text-ranking'} | cfg['preprocessor'] = {'type': 'text-ranking'} | ||||
| cfg.train.optimizer.lr = 2e-5 | cfg.train.optimizer.lr = 2e-5 | ||||
| @@ -73,7 +74,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| 'pos_sequence': 'positive_passages', | 'pos_sequence': 'positive_passages', | ||||
| 'neg_sequence': 'negative_passages', | 'neg_sequence': 'negative_passages', | ||||
| 'text_fileds': ['title', 'text'], | 'text_fileds': ['title', 'text'], | ||||
| 'qid_field': 'query_id' | |||||
| 'qid_field': 'query_id', | |||||
| 'neg_sample': neg_sample | |||||
| }, | }, | ||||
| 'val': { | 'val': { | ||||
| 'type': 'bert', | 'type': 'bert', | ||||
| @@ -84,7 +86,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| 'qid_field': 'query_id' | 'qid_field': 'query_id' | ||||
| }, | }, | ||||
| } | } | ||||
| cfg['train']['neg_samples'] = 4 | |||||
| cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 | cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 | ||||
| cfg.train.max_epochs = 1 | cfg.train.max_epochs = 1 | ||||
| cfg.train.train_batch_size = 4 | cfg.train.train_batch_size = 4 | ||||
| @@ -96,6 +97,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| 'by_epoch': False | 'by_epoch': False | ||||
| } | } | ||||
| } | } | ||||
| cfg.model['neg_sample'] = 4 | |||||
| cfg.train.hooks = [{ | cfg.train.hooks = [{ | ||||
| 'type': 'CheckpointHook', | 'type': 'CheckpointHook', | ||||
| 'interval': 1 | 'interval': 1 | ||||
| @@ -151,7 +153,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| 'qid_field': 'query_id' | 'qid_field': 'query_id' | ||||
| }, | }, | ||||
| } | } | ||||
| cfg['train']['neg_samples'] = 4 | |||||
| cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 | cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30 | ||||
| cfg.train.max_epochs = 1 | cfg.train.max_epochs = 1 | ||||
| cfg.train.train_batch_size = 4 | cfg.train.train_batch_size = 4 | ||||
| @@ -180,9 +181,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| # load dataset | # load dataset | ||||
| ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull') | ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull') | ||||
| train_ds = ds['train'].to_hf_dataset() | |||||
| train_ds = ds['train'].to_hf_dataset().shard(1000, index=0) | |||||
| dev_ds = ds['dev'].to_hf_dataset() | dev_ds = ds['dev'].to_hf_dataset() | ||||
| model_id = 'damo/nlp_rom_passage-ranking_chinese-base' | model_id = 'damo/nlp_rom_passage-ranking_chinese-base' | ||||
| self.finetune( | self.finetune( | ||||
| model_id=model_id, | model_id=model_id, | ||||