1. 将 single_gpu_test 与 multi_gpu_test 中的 model.forward 部分分离为 EpochBasedTrainer 中的 evaluation_step,为部分 evaluation 阶段不调用 forward 的模型提供更好的灵活性
2. 重构代码将文本生成模型 Model 层的输入输出统一为 Tensor,Tensor 到 str 的 decode 过程移动到 pipeline 中完成
3. pipeline 后处理添加对中文和中文标点与英文混杂时空格的处理,使 decode 后中英文混杂输出正确
4. 添加 TextGenerationTrainer 修复了部分模型 evaluation 过程 forward 输出单个 token 计算 metrics 的问题
5. 修复了 rouge 无法接收空字符串的问题
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10473768
master
| @@ -313,6 +313,7 @@ class Trainers(object): | |||||
| nlp_base_trainer = 'nlp-base-trainer' | nlp_base_trainer = 'nlp-base-trainer' | ||||
| nlp_veco_trainer = 'nlp-veco-trainer' | nlp_veco_trainer = 'nlp-veco-trainer' | ||||
| nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' | ||||
| text_generation_trainer = 'text-generation-trainer' | |||||
| # audio trainers | # audio trainers | ||||
| speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' | ||||
| @@ -36,20 +36,31 @@ class TextGenerationMetric(Metric): | |||||
| for char in string | for char in string | ||||
| ]).split()) | ]).split()) | ||||
| def add(self, outputs: Dict[str, List[str]], inputs: Dict = None): | |||||
| ground_truths = outputs['tgts'] | |||||
| def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]): | |||||
| ground_truths = inputs['tgts'] | |||||
| eval_results = outputs['preds'] | eval_results = outputs['preds'] | ||||
| for truth in ground_truths: | for truth in ground_truths: | ||||
| self.tgts.append(self.rebuild_str(truth)) | self.tgts.append(self.rebuild_str(truth)) | ||||
| for result in eval_results: | for result in eval_results: | ||||
| self.preds.append(self.rebuild_str(result)) | self.preds.append(self.rebuild_str(result)) | ||||
| def _check(self, pred: str, tgt: str) -> bool: | |||||
| def remove_useless(string: str) -> str: | |||||
| return string.replace(' ', '').replace('.', '') | |||||
| return remove_useless(pred) and remove_useless(tgt) | |||||
| def evaluate(self): | def evaluate(self): | ||||
| assert self.preds, 'preds in TextGenerationMetric must not be empty!' | |||||
| tmp = [(pred, tgt) for pred, tgt in zip(self.preds, self.tgts) | |||||
| if self._check(pred, tgt)] | |||||
| preds, tgts = zip(*tmp) | |||||
| def mean(iter: Iterable) -> float: | def mean(iter: Iterable) -> float: | ||||
| return sum(iter) / len(self.preds) | return sum(iter) / len(self.preds) | ||||
| rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts) | |||||
| rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) | |||||
| rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) | rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) | ||||
| rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) | rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) | ||||
| pred_split = tuple(pred.split(' ') for pred in self.preds) | pred_split = tuple(pred.split(' ') for pred in self.preds) | ||||
| @@ -49,7 +49,7 @@ if TYPE_CHECKING: | |||||
| VecoForSequenceClassification, | VecoForSequenceClassification, | ||||
| VecoForTokenClassification, VecoModel, VecoTokenizer, | VecoForTokenClassification, VecoModel, VecoTokenizer, | ||||
| VecoTokenizerFast) | VecoTokenizerFast) | ||||
| from .bloom import BloomModel | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'backbones': ['SbertModel'], | 'backbones': ['SbertModel'], | ||||
| @@ -107,6 +107,7 @@ else: | |||||
| 'sentence_embedding': ['SentenceEmbedding'], | 'sentence_embedding': ['SentenceEmbedding'], | ||||
| 'T5': ['T5ForConditionalGeneration'], | 'T5': ['T5ForConditionalGeneration'], | ||||
| 'gpt_neo': ['GPTNeoModel'], | 'gpt_neo': ['GPTNeoModel'], | ||||
| 'bloom': ['BloomModel'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,19 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .backbone import BloomModel | |||||
| else: | |||||
| _import_structure = { | |||||
| 'backbone': ['BloomModel'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -4,10 +4,10 @@ from transformers import BloomModel as BloomModelTransform | |||||
| from modelscope.metainfo import Models | from modelscope.metainfo import Models | ||||
| from modelscope.models.builder import BACKBONES | from modelscope.models.builder import BACKBONES | ||||
| from modelscope.utils.constant import Fields | |||||
| from modelscope.utils.constant import Tasks | |||||
| @BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom) | |||||
| @BACKBONES.register_module(group_key=Tasks.backbone, module_name=Models.bloom) | |||||
| class BloomModel(BloomModelTransform): | class BloomModel(BloomModelTransform): | ||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| @@ -42,7 +42,7 @@ class GPT3ForTextGeneration(TorchModel): | |||||
| """ | """ | ||||
| return self.model(**input) | return self.model(**input) | ||||
| def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]: | |||||
| def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| assert 'input_ids' in input, "generate function must accept 'input_ids' key" | assert 'input_ids' in input, "generate function must accept 'input_ids' key" | ||||
| input_ids = input['input_ids'] | input_ids = input['input_ids'] | ||||
| if 'attention_mask' in input: | if 'attention_mask' in input: | ||||
| @@ -59,8 +59,4 @@ class GPT3ForTextGeneration(TorchModel): | |||||
| gen_params['top_k'] = input.pop('top_k', 10) | gen_params['top_k'] = input.pop('top_k', 10) | ||||
| gen_params['top_p'] = input.pop('top_p', None) | gen_params['top_p'] = input.pop('top_p', None) | ||||
| sample_output = self.model.generate(**gen_params) | sample_output = self.model.generate(**gen_params) | ||||
| return { | |||||
| OutputKeys.TEXT: | |||||
| self.tokenizer.decode(sample_output[0], | |||||
| skip_special_tokens=True).replace(' ', '') | |||||
| } | |||||
| return {'sequences': sample_output[0]} | |||||
| @@ -1314,8 +1314,8 @@ class Translator(object): | |||||
| return results | return results | ||||
| def __call__(self, input_ids: torch.Tensor, | |||||
| attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]: | |||||
| def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, | |||||
| **kwargs) -> Dict[str, torch.Tensor]: | |||||
| batch = self.Batch( | batch = self.Batch( | ||||
| batch_size=input_ids.size()[0], | batch_size=input_ids.size()[0], | ||||
| src=input_ids, | src=input_ids, | ||||
| @@ -29,22 +29,6 @@ class PalmForTextGeneration(TorchModel): | |||||
| self.tokenizer = self.model.tokenizer | self.tokenizer = self.model.tokenizer | ||||
| self.generator = Translator(self.model) | self.generator = Translator(self.model) | ||||
| def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]: | |||||
| replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]', | |||||
| ''), | |||||
| (r' +', ' '), ('[SEP]', ''), ('[unused2]', ''), | |||||
| ('[CLS]', ''), ('[UNK]', ''), (' ', '')) | |||||
| replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '), | |||||
| ('<pad>', ''), ('<s>', ''), ('</s>', ''), | |||||
| ('<unk>', ' '), ('<q>', '. ')) | |||||
| replace_tokens = replace_tokens_roberta \ | |||||
| if self.model.config.encoder == 'roberta' else replace_tokens_bert | |||||
| strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list] | |||||
| for _old, _new in replace_tokens: | |||||
| strings = [s.replace(_old, _new) for s in strings] | |||||
| return strings | |||||
| def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||||
| """return the result by the model | """return the result by the model | ||||
| @@ -57,29 +41,10 @@ class PalmForTextGeneration(TorchModel): | |||||
| { | { | ||||
| 'loss': Tensor([12.34]), # loss for backward | 'loss': Tensor([12.34]), # loss for backward | ||||
| } | } | ||||
| or | |||||
| { | |||||
| 'preds': List["hello word"...] # the predicted strings | |||||
| 'tgts': List["hello world"...] # target strings | |||||
| } | |||||
| """ | """ | ||||
| if self.training: | |||||
| return self.model(**input) | |||||
| else: | |||||
| outputs = self.generator(input['input_ids'], | |||||
| input['attention_mask']) | |||||
| preds = outputs['predictions'] | |||||
| pred_ids_list = [ | |||||
| pred_batch[0].cpu().numpy().tolist() for pred_batch in preds | |||||
| ] | |||||
| tgt_ids_list = input['labels'].cpu().numpy().tolist() | |||||
| return { | |||||
| 'preds': self._evaluate_postprocess(pred_ids_list), | |||||
| 'tgts': self._evaluate_postprocess(tgt_ids_list) | |||||
| } | |||||
| return self.model(**input) | |||||
| def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]: | |||||
| def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: | |||||
| outputs = self.generator(**input) | outputs = self.generator(**input) | ||||
| preds = outputs['predictions'] | preds = outputs['predictions'] | ||||
| pred_ids_list = [preds[0][0].cpu().numpy().tolist()] | |||||
| return {OutputKeys.TEXT: self._evaluate_postprocess(pred_ids_list)[0]} | |||||
| return {'sequences': [pred[0] for pred in preds]} | |||||
| @@ -53,7 +53,7 @@ class TextGenerationPipeline(Pipeline): | |||||
| model = model if isinstance(model, | model = model if isinstance(model, | ||||
| Model) else Model.from_pretrained(model) | Model) else Model.from_pretrained(model) | ||||
| cfg = read_config(model.model_dir) | cfg = read_config(model.model_dir) | ||||
| self.postprocessor = cfg.pop('postprocessor', None) | |||||
| self.postprocessor = cfg.pop('postprocessor', 'decode') | |||||
| if preprocessor is None: | if preprocessor is None: | ||||
| preprocessor_cfg = cfg.preprocessor | preprocessor_cfg = cfg.preprocessor | ||||
| preprocessor_cfg.update({ | preprocessor_cfg.update({ | ||||
| @@ -78,8 +78,37 @@ class TextGenerationPipeline(Pipeline): | |||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| return self.model.generate(inputs, **forward_params) | return self.model.generate(inputs, **forward_params) | ||||
| def sentence_piece(self, inputs) -> Dict[str, Tensor]: | |||||
| return self.preprocessor.tokenizer.decode(inputs.tolist()[0]) | |||||
| def _is_chinese_char(self, word: str): | |||||
| chinese_punctuations = (',', '。', ';', ':' '!', '?', '《', '》') | |||||
| return len(word) == 1 \ | |||||
| and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations) | |||||
| def _remove_space_between_chinese_chars(self, decoded: str): | |||||
| old_word_list = decoded.split(' ') | |||||
| new_word_list = [] | |||||
| start = -1 | |||||
| for i, word in enumerate(old_word_list): | |||||
| if self._is_chinese_char(word): | |||||
| if start == -1: | |||||
| start = i | |||||
| else: | |||||
| if start != -1: | |||||
| new_word_list.append(''.join(old_word_list[start:i])) | |||||
| start = -1 | |||||
| new_word_list.append(word) | |||||
| if start != -1: | |||||
| new_word_list.append(''.join(old_word_list[start:])) | |||||
| return ' '.join(new_word_list) | |||||
| def decode(self, inputs) -> str: | |||||
| tokenizer = self.preprocessor.tokenizer | |||||
| return tokenizer.decode(inputs.tolist(), skip_special_tokens=True) | |||||
| def roberta(self, inputs) -> str: | |||||
| tokenizer = self.preprocessor.tokenizer | |||||
| decoded = tokenizer.decode(inputs.tolist()) | |||||
| return decoded.replace('<q>', '. ').replace('<mask>', | |||||
| '. ').replace('</s>', '') | |||||
| def postprocess(self, inputs: Dict[str, Tensor], | def postprocess(self, inputs: Dict[str, Tensor], | ||||
| **postprocess_params) -> Dict[str, str]: | **postprocess_params) -> Dict[str, str]: | ||||
| @@ -91,7 +120,9 @@ class TextGenerationPipeline(Pipeline): | |||||
| Returns: | Returns: | ||||
| Dict[str, str]: the prediction results | Dict[str, str]: the prediction results | ||||
| """ | """ | ||||
| return inputs if self.postprocessor is None else { | |||||
| OutputKeys.TEXT: | |||||
| getattr(self, self.postprocessor.replace('-', '_'))(inputs) | |||||
| } | |||||
| inputs = inputs['sequences'] | |||||
| if isinstance(inputs, list): | |||||
| inputs = inputs[0] | |||||
| decoded = getattr(self, self.postprocessor)(inputs) | |||||
| text = self._remove_space_between_chinese_chars(decoded) | |||||
| return {OutputKeys.TEXT: text} | |||||
| @@ -7,11 +7,13 @@ if TYPE_CHECKING: | |||||
| from .sequence_classification_trainer import SequenceClassificationTrainer | from .sequence_classification_trainer import SequenceClassificationTrainer | ||||
| from .csanmt_translation_trainer import CsanmtTranslationTrainer | from .csanmt_translation_trainer import CsanmtTranslationTrainer | ||||
| from .text_ranking_trainer import TextRankingTrainer | from .text_ranking_trainer import TextRankingTrainer | ||||
| from .text_generation_trainer import TextGenerationTrainer | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'sequence_classification_trainer': ['SequenceClassificationTrainer'], | 'sequence_classification_trainer': ['SequenceClassificationTrainer'], | ||||
| 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], | 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], | ||||
| 'text_ranking_trainer': ['TextRankingTrainer'] | |||||
| 'text_ranking_trainer': ['TextRankingTrainer'], | |||||
| 'text_generation_trainer': ['TextGenerationTrainer'], | |||||
| } | } | ||||
| import sys | import sys | ||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from collections.abc import Mapping | |||||
| import torch | |||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.trainers import NlpEpochBasedTrainer | |||||
| from modelscope.trainers.builder import TRAINERS | |||||
| from modelscope.utils.file_utils import func_receive_dict_inputs | |||||
| @TRAINERS.register_module(module_name=Trainers.text_generation_trainer) | |||||
| class TextGenerationTrainer(NlpEpochBasedTrainer): | |||||
| def _decode(self, tokens): | |||||
| tokenizer = self.eval_preprocessor.tokenizer | |||||
| return tokenizer.decode(tokens.tolist(), skip_special_tokens=True) | |||||
| def evaluation_step(self, data): | |||||
| model = self.model | |||||
| model.eval() | |||||
| with torch.no_grad(): | |||||
| if isinstance( | |||||
| data, | |||||
| Mapping) and not func_receive_dict_inputs(model.generate): | |||||
| result = model.generate(**data) | |||||
| else: | |||||
| result = model.generate(data) | |||||
| result['preds'] = [self._decode(seq) for seq in result['sequences']] | |||||
| data['tgts'] = [self._decode(seq) for seq in data['labels']] | |||||
| assert len(result['preds']) == len(data['tgts']) | |||||
| return result | |||||
| @@ -855,6 +855,28 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| self.invoke_hook(TrainerStages.after_run) | self.invoke_hook(TrainerStages.after_run) | ||||
| def evaluation_step(self, data): | |||||
| """Perform a training step on a batch of inputs. | |||||
| Subclass and override to inject custom behavior. | |||||
| """ | |||||
| model = self.model | |||||
| model.eval() | |||||
| if is_parallel(model): | |||||
| receive_dict_inputs = func_receive_dict_inputs( | |||||
| model.module.forward) | |||||
| else: | |||||
| receive_dict_inputs = func_receive_dict_inputs(model.forward) | |||||
| with torch.no_grad(): | |||||
| if isinstance(data, Mapping) and not receive_dict_inputs: | |||||
| result = model.forward(**data) | |||||
| else: | |||||
| result = model.forward(data) | |||||
| return result | |||||
| def evaluation_loop(self, data_loader, metric_classes): | def evaluation_loop(self, data_loader, metric_classes): | ||||
| """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. | """ Evaluation loop used by `EpochBasedTrainer.evaluate()`. | ||||
| @@ -862,7 +884,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| if self._dist: | if self._dist: | ||||
| from modelscope.trainers.utils.inference import multi_gpu_test | from modelscope.trainers.utils.inference import multi_gpu_test | ||||
| metric_values = multi_gpu_test( | metric_values = multi_gpu_test( | ||||
| self.model, | |||||
| self, | |||||
| data_loader, | data_loader, | ||||
| device=self.device, | device=self.device, | ||||
| tmpdir=None, | tmpdir=None, | ||||
| @@ -872,7 +894,7 @@ class EpochBasedTrainer(BaseTrainer): | |||||
| else: | else: | ||||
| from modelscope.trainers.utils.inference import single_gpu_test | from modelscope.trainers.utils.inference import single_gpu_test | ||||
| metric_values = single_gpu_test( | metric_values = single_gpu_test( | ||||
| self.model, | |||||
| self, | |||||
| data_loader, | data_loader, | ||||
| device=self.device, | device=self.device, | ||||
| metric_classes=metric_classes, | metric_classes=metric_classes, | ||||
| @@ -4,29 +4,25 @@ import logging | |||||
| import os | import os | ||||
| import pickle | import pickle | ||||
| import shutil | import shutil | ||||
| import time | |||||
| from collections.abc import Mapping | |||||
| import torch | import torch | ||||
| from torch import distributed as dist | from torch import distributed as dist | ||||
| from tqdm import tqdm | from tqdm import tqdm | ||||
| from modelscope.trainers.parallel.utils import is_parallel | |||||
| from modelscope.utils.data_utils import to_device | from modelscope.utils.data_utils import to_device | ||||
| from modelscope.utils.file_utils import func_receive_dict_inputs | |||||
| from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master, | from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master, | ||||
| make_tmp_dir) | make_tmp_dir) | ||||
| def single_gpu_test(model, | |||||
| def single_gpu_test(trainer, | |||||
| data_loader, | data_loader, | ||||
| device, | device, | ||||
| metric_classes=None, | metric_classes=None, | ||||
| data_loader_iters=None): | data_loader_iters=None): | ||||
| """Test model with a single gpu. | |||||
| """Test model in EpochBasedTrainer with a single gpu. | |||||
| Args: | Args: | ||||
| model (nn.Module): Model to be tested. | |||||
| trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested. | |||||
| data_loader (nn.Dataloader): Pytorch data loader. | data_loader (nn.Dataloader): Pytorch data loader. | ||||
| device (str | torch.device): The target device for the data. | device (str | torch.device): The target device for the data. | ||||
| metric_classes (List): List of Metric class that uses to collect metrics | metric_classes (List): List of Metric class that uses to collect metrics | ||||
| @@ -35,7 +31,6 @@ def single_gpu_test(model, | |||||
| Returns: | Returns: | ||||
| list: The prediction results. | list: The prediction results. | ||||
| """ | """ | ||||
| model.eval() | |||||
| dataset = data_loader.dataset | dataset = data_loader.dataset | ||||
| progress_with_iters = False | progress_with_iters = False | ||||
| if data_loader_iters is None: | if data_loader_iters is None: | ||||
| @@ -55,12 +50,7 @@ def single_gpu_test(model, | |||||
| with tqdm(total=data_len, desc=desc) as pbar: | with tqdm(total=data_len, desc=desc) as pbar: | ||||
| for i, data in enumerate(data_loader): | for i, data in enumerate(data_loader): | ||||
| data = to_device(data, device) | data = to_device(data, device) | ||||
| with torch.no_grad(): | |||||
| if isinstance(data, Mapping) and not func_receive_dict_inputs( | |||||
| model.forward): | |||||
| result = model.forward(**data) | |||||
| else: | |||||
| result = model.forward(data) | |||||
| result = trainer.evaluation_step(data) | |||||
| if metric_classes is not None: | if metric_classes is not None: | ||||
| for metric_cls in metric_classes: | for metric_cls in metric_classes: | ||||
| metric_cls.add(result, data) | metric_cls.add(result, data) | ||||
| @@ -88,14 +78,14 @@ def single_gpu_test(model, | |||||
| return metric_values | return metric_values | ||||
| def multi_gpu_test(model, | |||||
| def multi_gpu_test(trainer, | |||||
| data_loader, | data_loader, | ||||
| device, | device, | ||||
| tmpdir=None, | tmpdir=None, | ||||
| gpu_collect=False, | gpu_collect=False, | ||||
| metric_classes=None, | metric_classes=None, | ||||
| data_loader_iters_per_gpu=None): | data_loader_iters_per_gpu=None): | ||||
| """Test model with multiple gpus. | |||||
| """Test model in EpochBasedTrainer with multiple gpus. | |||||
| This method tests model with multiple gpus and collects the results | This method tests model with multiple gpus and collects the results | ||||
| under two different modes: gpu and cpu modes. By setting | under two different modes: gpu and cpu modes. By setting | ||||
| @@ -104,7 +94,7 @@ def multi_gpu_test(model, | |||||
| different gpus to ``tmpdir`` and collects them by the rank 0 worker. | different gpus to ``tmpdir`` and collects them by the rank 0 worker. | ||||
| Args: | Args: | ||||
| model (nn.Module): Model to be tested. | |||||
| trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested. | |||||
| data_loader (nn.Dataloader): Pytorch data loader. | data_loader (nn.Dataloader): Pytorch data loader. | ||||
| device: (str | torch.device): The target device for the data. | device: (str | torch.device): The target device for the data. | ||||
| tmpdir (str): Path of directory to save the temporary results from | tmpdir (str): Path of directory to save the temporary results from | ||||
| @@ -115,7 +105,6 @@ def multi_gpu_test(model, | |||||
| Returns: | Returns: | ||||
| list: The prediction results. | list: The prediction results. | ||||
| """ | """ | ||||
| model.eval() | |||||
| results = [] | results = [] | ||||
| data_list = [] | data_list = [] | ||||
| dataset = data_loader.dataset | dataset = data_loader.dataset | ||||
| @@ -138,21 +127,12 @@ def multi_gpu_test(model, | |||||
| data_len = data_loader_iters_per_gpu * world_size | data_len = data_loader_iters_per_gpu * world_size | ||||
| desc = 'Total test iterations with multi gpus' | desc = 'Total test iterations with multi gpus' | ||||
| if is_parallel(model): | |||||
| receive_dict_inputs = func_receive_dict_inputs(model.module.forward) | |||||
| else: | |||||
| receive_dict_inputs = func_receive_dict_inputs(model.forward) | |||||
| count = 0 | count = 0 | ||||
| with tqdm(total=data_len, desc=desc) as pbar: | with tqdm(total=data_len, desc=desc) as pbar: | ||||
| for i, data in enumerate(data_loader): | for i, data in enumerate(data_loader): | ||||
| data = to_device(data, device) | data = to_device(data, device) | ||||
| data_list.append(data) | data_list.append(data) | ||||
| with torch.no_grad(): | |||||
| if isinstance(data, Mapping) and not receive_dict_inputs: | |||||
| result = model.forward(**data) | |||||
| else: | |||||
| result = model.forward(data) | |||||
| result = trainer.evaluation_step(data) | |||||
| results.append(result) | results.append(result) | ||||
| if isinstance(data, dict): | if isinstance(data, dict): | ||||
| @@ -59,7 +59,7 @@ class TestFinetuneTextGeneration(unittest.TestCase): | |||||
| work_dir=self.tmp_dir) | work_dir=self.tmp_dir) | ||||
| trainer = build_trainer( | trainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.text_generation_trainer, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| @@ -98,7 +98,7 @@ class TestFinetuneTextGeneration(unittest.TestCase): | |||||
| work_dir=self.tmp_dir) | work_dir=self.tmp_dir) | ||||
| trainer = build_trainer( | trainer = build_trainer( | ||||
| name=Trainers.nlp_base_trainer, default_args=kwargs) | |||||
| name=Trainers.text_generation_trainer, default_args=kwargs) | |||||
| trainer.train() | trainer.train() | ||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| @@ -12,6 +12,7 @@ from modelscope.metrics.builder import MetricKeys | |||||
| from modelscope.metrics.sequence_classification_metric import \ | from modelscope.metrics.sequence_classification_metric import \ | ||||
| SequenceClassificationMetric | SequenceClassificationMetric | ||||
| from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
| from modelscope.trainers import EpochBasedTrainer | |||||
| from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test | from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test | ||||
| from modelscope.utils.test_utils import (DistributedTestCase, | from modelscope.utils.test_utils import (DistributedTestCase, | ||||
| create_dummy_test_dataset, test_level) | create_dummy_test_dataset, test_level) | ||||
| @@ -36,6 +37,12 @@ class DummyModel(nn.Module, Model): | |||||
| return dict(logits=x, loss=loss) | return dict(logits=x, loss=loss) | ||||
| class DummyTrainer(EpochBasedTrainer): | |||||
| def __init__(self, model): | |||||
| self.model = model | |||||
| def test_func(dist=False): | def test_func(dist=False): | ||||
| dummy_model = DummyModel() | dummy_model = DummyModel() | ||||
| dataset = dummy_dataset.to_torch_dataset() | dataset = dummy_dataset.to_torch_dataset() | ||||
| @@ -62,8 +69,10 @@ def test_func(dist=False): | |||||
| else: | else: | ||||
| test_func = single_gpu_test | test_func = single_gpu_test | ||||
| dummy_trainer = DummyTrainer(dummy_model) | |||||
| metric_results = test_func( | metric_results = test_func( | ||||
| dummy_model, | |||||
| dummy_trainer, | |||||
| dummy_loader, | dummy_loader, | ||||
| device=device, | device=device, | ||||
| metric_classes=[metric_class]) | metric_classes=[metric_class]) | ||||