diff --git a/modelscope/models/nlp/bert/text_ranking.py b/modelscope/models/nlp/bert/text_ranking.py
index 79a63045..d6bbf277 100644
--- a/modelscope/models/nlp/bert/text_ranking.py
+++ b/modelscope/models/nlp/bert/text_ranking.py
@@ -18,14 +18,12 @@ logger = logging.get_logger(__name__)
 @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert)
 class BertForTextRanking(BertForSequenceClassification):
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         super().__init__(config)
-        self.train_batch_size = kwargs.get('train_batch_size', 4)
+        neg_sample = kwargs.get('neg_sample', 8)
+        self.neg_sample = neg_sample
         setattr(self, self.base_model_prefix,
                 BertModel(self.config, add_pooling_layer=True))
-        self.register_buffer(
-            'target_label',
-            torch.zeros(self.train_batch_size, dtype=torch.long))
 
     def forward(self,
                 input_ids=None,
@@ -55,9 +53,12 @@ class BertForTextRanking(BertForSequenceClassification):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         if self.base_model.training:
-            scores = logits.view(self.train_batch_size, -1)
+            scores = logits.view(-1, self.neg_sample + 1)
+            batch_size = scores.size(0)
             loss_fct = torch.nn.CrossEntropyLoss()
-            loss = loss_fct(scores, self.target_label)
+            target_label = torch.zeros(
+                batch_size, dtype=torch.long, device=scores.device)
+            loss = loss_fct(scores, target_label)
         return AttentionTextClassificationModelOutput(
             loss=loss,
             logits=logits,
@@ -78,9 +79,12 @@ class BertForTextRanking(BertForSequenceClassification):
         Returns:
             The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
         """
         num_labels = kwargs.get('num_labels', 1)
+        neg_sample = kwargs.get('neg_sample', 4)
         model_args = {} if num_labels is None else {'num_labels': num_labels}
+        if neg_sample is not None:
+            model_args['neg_sample'] = neg_sample
 
         model_dir = kwargs.get('model_dir')
 
         model = super(Model, cls).from_pretrained(
diff --git a/modelscope/msdatasets/task_datasets/text_ranking_dataset.py b/modelscope/msdatasets/task_datasets/text_ranking_dataset.py
index dd44f7c2..54276843 100644
--- a/modelscope/msdatasets/task_datasets/text_ranking_dataset.py
+++ b/modelscope/msdatasets/task_datasets/text_ranking_dataset.py
@@ -39,8 +39,7 @@ class TextRankingDataset(TorchTaskDataset):
                                                     ['title', 'text'])
         self.qid_field = self.dataset_config.get('qid_field', 'query_id')
         if mode == ModeKeys.TRAIN:
-            train_config = kwargs.get('train', {})
-            self.neg_samples = train_config.get('neg_samples', 4)
+            self.neg_samples = self.dataset_config.get('neg_sample', 4)
 
         super().__init__(datasets, mode, preprocessor, **kwargs)
 
diff --git a/tests/trainers/test_finetune_text_ranking.py b/tests/trainers/test_finetune_text_ranking.py
index 3561cb46..6e97310d 100644
--- a/tests/trainers/test_finetune_text_ranking.py
+++ b/tests/trainers/test_finetune_text_ranking.py
@@ -63,6 +63,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
     def test_finetune_msmarco(self):
 
         def cfg_modify_fn(cfg):
+            neg_sample = 4
             cfg.task = 'text-ranking'
             cfg['preprocessor'] = {'type': 'text-ranking'}
             cfg.train.optimizer.lr = 2e-5
@@ -73,7 +74,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'pos_sequence': 'positive_passages',
                     'neg_sequence': 'negative_passages',
                     'text_fileds': ['title', 'text'],
-                    'qid_field': 'query_id'
+                    'qid_field': 'query_id',
+                    'neg_sample': neg_sample
                 },
                 'val': {
                     'type': 'bert',
@@ -84,7 +86,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'qid_field': 'query_id'
                 },
             }
-            cfg['train']['neg_samples'] = 4
             cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
             cfg.train.max_epochs = 1
             cfg.train.train_batch_size = 4
@@ -96,6 +97,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'by_epoch': False
                 }
             }
+            cfg.model['neg_sample'] = 4
             cfg.train.hooks = [{
                 'type': 'CheckpointHook',
                 'interval': 1
@@ -151,7 +153,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'qid_field': 'query_id'
                 },
             }
-            cfg['train']['neg_samples'] = 4
             cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
             cfg.train.max_epochs = 1
             cfg.train.train_batch_size = 4
@@ -180,9 +181,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
 
         # load dataset
         ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')
-        train_ds = ds['train'].to_hf_dataset()
+        train_ds = ds['train'].to_hf_dataset().shard(1000, index=0)
         dev_ds = ds['dev'].to_hf_dataset()
-
         model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
         self.finetune(
             model_id=model_id,
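
The core of this change is the training loss in BertForTextRanking.forward: instead of a target buffer sized by a fixed train_batch_size, the per-passage logits are regrouped into rows of neg_sample + 1 candidates (one positive followed by neg_sample negatives) and cross-entropy is taken against index 0. Below is a minimal standalone sketch of that grouping, assuming the dataloader emits candidates in exactly that order; the helper name and dummy tensors are illustrative and not part of the patch:

    import torch
    import torch.nn.functional as F

    def listwise_ranking_loss(logits: torch.Tensor, neg_sample: int) -> torch.Tensor:
        # logits: shape (num_queries * (neg_sample + 1), 1), one score per passage,
        # with each query's positive passage first, followed by its negatives.
        scores = logits.view(-1, neg_sample + 1)  # (num_queries, neg_sample + 1)
        # The positive passage always sits at column 0 of every group.
        target = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
        return F.cross_entropy(scores, target)

    # Example: 2 queries with 4 negatives each -> 10 scores in total.
    dummy_logits = torch.randn(10, 1)
    loss = listwise_ranking_loss(dummy_logits, neg_sample=4)

This is also why the test config passes the same value to both the dataset ('neg_sample': neg_sample) and the model (cfg.model['neg_sample'] = 4): the group size assumed by the loss has to match how the dataset interleaves positives and negatives.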