Browse Source

[to #42322933] Refactor text generation model outputs and fix some bugs

1. 将 single_gpu_test 与 multi_gpu_test 中的 model.forward 部分分离为 EpochBasedTrainer 中的 evaluation_step,为部分 evaluation 阶段不调用 forward 的模型提供更好的灵活性
2. 重构代码将文本生成模型 Model 层的输入输出统一为 Tensor,Tensor 到 str 的 decode 过程移动到 pipeline 中完成
3. pipeline 后处理添加对中文和中文标点与英文混杂时空格的处理,使 decode 后中英文混杂输出正确
4. 添加 TextGenerationTrainer 修复了部分模型 evaluation 过程 forward 输出单个 token 计算 metrics 的问题
5. 修复了 rouge 无法接收空字符串的问题
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10473768
master
hemu.zp yingda.chen 3 years ago
parent
commit
69104c0f8a
15 changed files with 166 additions and 93 deletions
  1. +1
    -0
      modelscope/metainfo.py
  2. +14
    -3
      modelscope/metrics/text_generation_metric.py
  3. +2
    -1
      modelscope/models/nlp/__init__.py
  4. +19
    -0
      modelscope/models/nlp/bloom/__init__.py
  5. +2
    -2
      modelscope/models/nlp/bloom/backbone.py
  6. +2
    -6
      modelscope/models/nlp/gpt3/text_generation.py
  7. +2
    -2
      modelscope/models/nlp/palm_v2/backbone.py
  8. +3
    -38
      modelscope/models/nlp/palm_v2/text_generation.py
  9. +38
    -7
      modelscope/pipelines/nlp/text_generation_pipeline.py
  10. +3
    -1
      modelscope/trainers/nlp/__init__.py
  11. +36
    -0
      modelscope/trainers/nlp/text_generation_trainer.py
  12. +24
    -2
      modelscope/trainers/trainer.py
  13. +8
    -28
      modelscope/trainers/utils/inference.py
  14. +2
    -2
      tests/trainers/test_finetune_text_generation.py
  15. +10
    -1
      tests/trainers/utils/test_inference.py

+ 1
- 0
modelscope/metainfo.py View File

@@ -313,6 +313,7 @@ class Trainers(object):
nlp_base_trainer = 'nlp-base-trainer' nlp_base_trainer = 'nlp-base-trainer'
nlp_veco_trainer = 'nlp-veco-trainer' nlp_veco_trainer = 'nlp-veco-trainer'
nlp_text_ranking_trainer = 'nlp-text-ranking-trainer' nlp_text_ranking_trainer = 'nlp-text-ranking-trainer'
text_generation_trainer = 'text-generation-trainer'


# audio trainers # audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'


+ 14
- 3
modelscope/metrics/text_generation_metric.py View File

@@ -36,20 +36,31 @@ class TextGenerationMetric(Metric):
for char in string for char in string
]).split()) ]).split())


def add(self, outputs: Dict[str, List[str]], inputs: Dict = None):
ground_truths = outputs['tgts']
def add(self, outputs: Dict[str, List[str]], inputs: Dict[str, List[str]]):
ground_truths = inputs['tgts']
eval_results = outputs['preds'] eval_results = outputs['preds']
for truth in ground_truths: for truth in ground_truths:
self.tgts.append(self.rebuild_str(truth)) self.tgts.append(self.rebuild_str(truth))
for result in eval_results: for result in eval_results:
self.preds.append(self.rebuild_str(result)) self.preds.append(self.rebuild_str(result))


def _check(self, pred: str, tgt: str) -> bool:

def remove_useless(string: str) -> str:
return string.replace(' ', '').replace('.', '')

return remove_useless(pred) and remove_useless(tgt)

def evaluate(self): def evaluate(self):
assert self.preds, 'preds in TextGenerationMetric must not be empty!'
tmp = [(pred, tgt) for pred, tgt in zip(self.preds, self.tgts)
if self._check(pred, tgt)]
preds, tgts = zip(*tmp)


def mean(iter: Iterable) -> float: def mean(iter: Iterable) -> float:
return sum(iter) / len(self.preds) return sum(iter) / len(self.preds)


rouge_scores = self.rouge.get_scores(hyps=self.preds, refs=self.tgts)
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts)
rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores))
rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores))
pred_split = tuple(pred.split(' ') for pred in self.preds) pred_split = tuple(pred.split(' ') for pred in self.preds)


+ 2
- 1
modelscope/models/nlp/__init__.py View File

@@ -49,7 +49,7 @@ if TYPE_CHECKING:
VecoForSequenceClassification, VecoForSequenceClassification,
VecoForTokenClassification, VecoModel, VecoTokenizer, VecoForTokenClassification, VecoModel, VecoTokenizer,
VecoTokenizerFast) VecoTokenizerFast)
from .bloom import BloomModel
else: else:
_import_structure = { _import_structure = {
'backbones': ['SbertModel'], 'backbones': ['SbertModel'],
@@ -107,6 +107,7 @@ else:
'sentence_embedding': ['SentenceEmbedding'], 'sentence_embedding': ['SentenceEmbedding'],
'T5': ['T5ForConditionalGeneration'], 'T5': ['T5ForConditionalGeneration'],
'gpt_neo': ['GPTNeoModel'], 'gpt_neo': ['GPTNeoModel'],
'bloom': ['BloomModel'],
} }


import sys import sys


+ 19
- 0
modelscope/models/nlp/bloom/__init__.py View File

@@ -0,0 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .backbone import BloomModel
else:
_import_structure = {
'backbone': ['BloomModel'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+ 2
- 2
modelscope/models/nlp/bloom/backbone.py View File

@@ -4,10 +4,10 @@ from transformers import BloomModel as BloomModelTransform


from modelscope.metainfo import Models from modelscope.metainfo import Models
from modelscope.models.builder import BACKBONES from modelscope.models.builder import BACKBONES
from modelscope.utils.constant import Fields
from modelscope.utils.constant import Tasks




@BACKBONES.register_module(group_key=Fields.nlp, module_name=Models.bloom)
@BACKBONES.register_module(group_key=Tasks.backbone, module_name=Models.bloom)
class BloomModel(BloomModelTransform): class BloomModel(BloomModelTransform):


def __init__(self, **kwargs): def __init__(self, **kwargs):


+ 2
- 6
modelscope/models/nlp/gpt3/text_generation.py View File

@@ -42,7 +42,7 @@ class GPT3ForTextGeneration(TorchModel):
""" """
return self.model(**input) return self.model(**input)


def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
assert 'input_ids' in input, "generate function must accept 'input_ids' key" assert 'input_ids' in input, "generate function must accept 'input_ids' key"
input_ids = input['input_ids'] input_ids = input['input_ids']
if 'attention_mask' in input: if 'attention_mask' in input:
@@ -59,8 +59,4 @@ class GPT3ForTextGeneration(TorchModel):
gen_params['top_k'] = input.pop('top_k', 10) gen_params['top_k'] = input.pop('top_k', 10)
gen_params['top_p'] = input.pop('top_p', None) gen_params['top_p'] = input.pop('top_p', None)
sample_output = self.model.generate(**gen_params) sample_output = self.model.generate(**gen_params)
return {
OutputKeys.TEXT:
self.tokenizer.decode(sample_output[0],
skip_special_tokens=True).replace(' ', '')
}
return {'sequences': sample_output[0]}

+ 2
- 2
modelscope/models/nlp/palm_v2/backbone.py View File

@@ -1314,8 +1314,8 @@ class Translator(object):


return results return results


def __call__(self, input_ids: torch.Tensor,
attention_mask: torch.Tensor) -> Dict[str, torch.Tensor]:
def __call__(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
**kwargs) -> Dict[str, torch.Tensor]:
batch = self.Batch( batch = self.Batch(
batch_size=input_ids.size()[0], batch_size=input_ids.size()[0],
src=input_ids, src=input_ids,


+ 3
- 38
modelscope/models/nlp/palm_v2/text_generation.py View File

@@ -29,22 +29,6 @@ class PalmForTextGeneration(TorchModel):
self.tokenizer = self.model.tokenizer self.tokenizer = self.model.tokenizer
self.generator = Translator(self.model) self.generator = Translator(self.model)


def _evaluate_postprocess(self, ids_list: List[List[int]]) -> List[str]:
replace_tokens_bert = (('[unused0]', ''), ('[PAD]', ''), ('[unused1]',
''),
(r' +', ' '), ('[SEP]', ''), ('[unused2]', ''),
('[CLS]', ''), ('[UNK]', ''), (' ', ''))
replace_tokens_roberta = ((r' +', ' '), ('<mask>', '. '),
('<pad>', ''), ('<s>', ''), ('</s>', ''),
('<unk>', ' '), ('<q>', '. '))

replace_tokens = replace_tokens_roberta \
if self.model.config.encoder == 'roberta' else replace_tokens_bert
strings = [self.tokenizer.decode(pred_ids) for pred_ids in ids_list]
for _old, _new in replace_tokens:
strings = [s.replace(_old, _new) for s in strings]
return strings

def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]: def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""return the result by the model """return the result by the model


@@ -57,29 +41,10 @@ class PalmForTextGeneration(TorchModel):
{ {
'loss': Tensor([12.34]), # loss for backward 'loss': Tensor([12.34]), # loss for backward
} }
or
{
'preds': List["hello word"...] # the predicted strings
'tgts': List["hello world"...] # target strings
}
""" """
if self.training:
return self.model(**input)
else:
outputs = self.generator(input['input_ids'],
input['attention_mask'])
preds = outputs['predictions']
pred_ids_list = [
pred_batch[0].cpu().numpy().tolist() for pred_batch in preds
]
tgt_ids_list = input['labels'].cpu().numpy().tolist()
return {
'preds': self._evaluate_postprocess(pred_ids_list),
'tgts': self._evaluate_postprocess(tgt_ids_list)
}
return self.model(**input)


def generate(self, input: Dict[str, Tensor]) -> Dict[str, str]:
def generate(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
outputs = self.generator(**input) outputs = self.generator(**input)
preds = outputs['predictions'] preds = outputs['predictions']
pred_ids_list = [preds[0][0].cpu().numpy().tolist()]
return {OutputKeys.TEXT: self._evaluate_postprocess(pred_ids_list)[0]}
return {'sequences': [pred[0] for pred in preds]}

+ 38
- 7
modelscope/pipelines/nlp/text_generation_pipeline.py View File

@@ -53,7 +53,7 @@ class TextGenerationPipeline(Pipeline):
model = model if isinstance(model, model = model if isinstance(model,
Model) else Model.from_pretrained(model) Model) else Model.from_pretrained(model)
cfg = read_config(model.model_dir) cfg = read_config(model.model_dir)
self.postprocessor = cfg.pop('postprocessor', None)
self.postprocessor = cfg.pop('postprocessor', 'decode')
if preprocessor is None: if preprocessor is None:
preprocessor_cfg = cfg.preprocessor preprocessor_cfg = cfg.preprocessor
preprocessor_cfg.update({ preprocessor_cfg.update({
@@ -78,8 +78,37 @@ class TextGenerationPipeline(Pipeline):
with torch.no_grad(): with torch.no_grad():
return self.model.generate(inputs, **forward_params) return self.model.generate(inputs, **forward_params)


def sentence_piece(self, inputs) -> Dict[str, Tensor]:
return self.preprocessor.tokenizer.decode(inputs.tolist()[0])
def _is_chinese_char(self, word: str):
chinese_punctuations = (',', '。', ';', ':' '!', '?', '《', '》')
return len(word) == 1 \
and ('\u4e00' <= word <= '\u9fa5' or word in chinese_punctuations)

def _remove_space_between_chinese_chars(self, decoded: str):
old_word_list = decoded.split(' ')
new_word_list = []
start = -1
for i, word in enumerate(old_word_list):
if self._is_chinese_char(word):
if start == -1:
start = i
else:
if start != -1:
new_word_list.append(''.join(old_word_list[start:i]))
start = -1
new_word_list.append(word)
if start != -1:
new_word_list.append(''.join(old_word_list[start:]))
return ' '.join(new_word_list)

def decode(self, inputs) -> str:
tokenizer = self.preprocessor.tokenizer
return tokenizer.decode(inputs.tolist(), skip_special_tokens=True)

def roberta(self, inputs) -> str:
tokenizer = self.preprocessor.tokenizer
decoded = tokenizer.decode(inputs.tolist())
return decoded.replace('<q>', '. ').replace('<mask>',
'. ').replace('</s>', '')


def postprocess(self, inputs: Dict[str, Tensor], def postprocess(self, inputs: Dict[str, Tensor],
**postprocess_params) -> Dict[str, str]: **postprocess_params) -> Dict[str, str]:
@@ -91,7 +120,9 @@ class TextGenerationPipeline(Pipeline):
Returns: Returns:
Dict[str, str]: the prediction results Dict[str, str]: the prediction results
""" """
return inputs if self.postprocessor is None else {
OutputKeys.TEXT:
getattr(self, self.postprocessor.replace('-', '_'))(inputs)
}
inputs = inputs['sequences']
if isinstance(inputs, list):
inputs = inputs[0]
decoded = getattr(self, self.postprocessor)(inputs)
text = self._remove_space_between_chinese_chars(decoded)
return {OutputKeys.TEXT: text}

+ 3
- 1
modelscope/trainers/nlp/__init__.py View File

@@ -7,11 +7,13 @@ if TYPE_CHECKING:
from .sequence_classification_trainer import SequenceClassificationTrainer from .sequence_classification_trainer import SequenceClassificationTrainer
from .csanmt_translation_trainer import CsanmtTranslationTrainer from .csanmt_translation_trainer import CsanmtTranslationTrainer
from .text_ranking_trainer import TextRankingTrainer from .text_ranking_trainer import TextRankingTrainer
from .text_generation_trainer import TextGenerationTrainer
else: else:
_import_structure = { _import_structure = {
'sequence_classification_trainer': ['SequenceClassificationTrainer'], 'sequence_classification_trainer': ['SequenceClassificationTrainer'],
'csanmt_translation_trainer': ['CsanmtTranslationTrainer'], 'csanmt_translation_trainer': ['CsanmtTranslationTrainer'],
'text_ranking_trainer': ['TextRankingTrainer']
'text_ranking_trainer': ['TextRankingTrainer'],
'text_generation_trainer': ['TextGenerationTrainer'],
} }


import sys import sys


+ 36
- 0
modelscope/trainers/nlp/text_generation_trainer.py View File

@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from collections.abc import Mapping

import torch

from modelscope.metainfo import Trainers
from modelscope.trainers import NlpEpochBasedTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.file_utils import func_receive_dict_inputs


@TRAINERS.register_module(module_name=Trainers.text_generation_trainer)
class TextGenerationTrainer(NlpEpochBasedTrainer):

def _decode(self, tokens):
tokenizer = self.eval_preprocessor.tokenizer
return tokenizer.decode(tokens.tolist(), skip_special_tokens=True)

def evaluation_step(self, data):
model = self.model
model.eval()

with torch.no_grad():
if isinstance(
data,
Mapping) and not func_receive_dict_inputs(model.generate):
result = model.generate(**data)
else:
result = model.generate(data)

result['preds'] = [self._decode(seq) for seq in result['sequences']]
data['tgts'] = [self._decode(seq) for seq in data['labels']]
assert len(result['preds']) == len(data['tgts'])

return result

+ 24
- 2
modelscope/trainers/trainer.py View File

@@ -855,6 +855,28 @@ class EpochBasedTrainer(BaseTrainer):


self.invoke_hook(TrainerStages.after_run) self.invoke_hook(TrainerStages.after_run)


def evaluation_step(self, data):
"""Perform a training step on a batch of inputs.

Subclass and override to inject custom behavior.

"""
model = self.model
model.eval()

if is_parallel(model):
receive_dict_inputs = func_receive_dict_inputs(
model.module.forward)
else:
receive_dict_inputs = func_receive_dict_inputs(model.forward)

with torch.no_grad():
if isinstance(data, Mapping) and not receive_dict_inputs:
result = model.forward(**data)
else:
result = model.forward(data)
return result

def evaluation_loop(self, data_loader, metric_classes): def evaluation_loop(self, data_loader, metric_classes):
""" Evaluation loop used by `EpochBasedTrainer.evaluate()`. """ Evaluation loop used by `EpochBasedTrainer.evaluate()`.


@@ -862,7 +884,7 @@ class EpochBasedTrainer(BaseTrainer):
if self._dist: if self._dist:
from modelscope.trainers.utils.inference import multi_gpu_test from modelscope.trainers.utils.inference import multi_gpu_test
metric_values = multi_gpu_test( metric_values = multi_gpu_test(
self.model,
self,
data_loader, data_loader,
device=self.device, device=self.device,
tmpdir=None, tmpdir=None,
@@ -872,7 +894,7 @@ class EpochBasedTrainer(BaseTrainer):
else: else:
from modelscope.trainers.utils.inference import single_gpu_test from modelscope.trainers.utils.inference import single_gpu_test
metric_values = single_gpu_test( metric_values = single_gpu_test(
self.model,
self,
data_loader, data_loader,
device=self.device, device=self.device,
metric_classes=metric_classes, metric_classes=metric_classes,


+ 8
- 28
modelscope/trainers/utils/inference.py View File

@@ -4,29 +4,25 @@ import logging
import os import os
import pickle import pickle
import shutil import shutil
import time
from collections.abc import Mapping


import torch import torch
from torch import distributed as dist from torch import distributed as dist
from tqdm import tqdm from tqdm import tqdm


from modelscope.trainers.parallel.utils import is_parallel
from modelscope.utils.data_utils import to_device from modelscope.utils.data_utils import to_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master, from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,
make_tmp_dir) make_tmp_dir)




def single_gpu_test(model,
def single_gpu_test(trainer,
data_loader, data_loader,
device, device,
metric_classes=None, metric_classes=None,
data_loader_iters=None): data_loader_iters=None):
"""Test model with a single gpu.
"""Test model in EpochBasedTrainer with a single gpu.


Args: Args:
model (nn.Module): Model to be tested.
trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested.
data_loader (nn.Dataloader): Pytorch data loader. data_loader (nn.Dataloader): Pytorch data loader.
device (str | torch.device): The target device for the data. device (str | torch.device): The target device for the data.
metric_classes (List): List of Metric class that uses to collect metrics metric_classes (List): List of Metric class that uses to collect metrics
@@ -35,7 +31,6 @@ def single_gpu_test(model,
Returns: Returns:
list: The prediction results. list: The prediction results.
""" """
model.eval()
dataset = data_loader.dataset dataset = data_loader.dataset
progress_with_iters = False progress_with_iters = False
if data_loader_iters is None: if data_loader_iters is None:
@@ -55,12 +50,7 @@ def single_gpu_test(model,
with tqdm(total=data_len, desc=desc) as pbar: with tqdm(total=data_len, desc=desc) as pbar:
for i, data in enumerate(data_loader): for i, data in enumerate(data_loader):
data = to_device(data, device) data = to_device(data, device)
with torch.no_grad():
if isinstance(data, Mapping) and not func_receive_dict_inputs(
model.forward):
result = model.forward(**data)
else:
result = model.forward(data)
result = trainer.evaluation_step(data)
if metric_classes is not None: if metric_classes is not None:
for metric_cls in metric_classes: for metric_cls in metric_classes:
metric_cls.add(result, data) metric_cls.add(result, data)
@@ -88,14 +78,14 @@ def single_gpu_test(model,
return metric_values return metric_values




def multi_gpu_test(model,
def multi_gpu_test(trainer,
data_loader, data_loader,
device, device,
tmpdir=None, tmpdir=None,
gpu_collect=False, gpu_collect=False,
metric_classes=None, metric_classes=None,
data_loader_iters_per_gpu=None): data_loader_iters_per_gpu=None):
"""Test model with multiple gpus.
"""Test model in EpochBasedTrainer with multiple gpus.


This method tests model with multiple gpus and collects the results This method tests model with multiple gpus and collects the results
under two different modes: gpu and cpu modes. By setting under two different modes: gpu and cpu modes. By setting
@@ -104,7 +94,7 @@ def multi_gpu_test(model,
different gpus to ``tmpdir`` and collects them by the rank 0 worker. different gpus to ``tmpdir`` and collects them by the rank 0 worker.


Args: Args:
model (nn.Module): Model to be tested.
trainer (modelscope.trainers.EpochBasedTrainer): Trainer to be tested.
data_loader (nn.Dataloader): Pytorch data loader. data_loader (nn.Dataloader): Pytorch data loader.
device: (str | torch.device): The target device for the data. device: (str | torch.device): The target device for the data.
tmpdir (str): Path of directory to save the temporary results from tmpdir (str): Path of directory to save the temporary results from
@@ -115,7 +105,6 @@ def multi_gpu_test(model,
Returns: Returns:
list: The prediction results. list: The prediction results.
""" """
model.eval()
results = [] results = []
data_list = [] data_list = []
dataset = data_loader.dataset dataset = data_loader.dataset
@@ -138,21 +127,12 @@ def multi_gpu_test(model,
data_len = data_loader_iters_per_gpu * world_size data_len = data_loader_iters_per_gpu * world_size
desc = 'Total test iterations with multi gpus' desc = 'Total test iterations with multi gpus'


if is_parallel(model):
receive_dict_inputs = func_receive_dict_inputs(model.module.forward)
else:
receive_dict_inputs = func_receive_dict_inputs(model.forward)

count = 0 count = 0
with tqdm(total=data_len, desc=desc) as pbar: with tqdm(total=data_len, desc=desc) as pbar:
for i, data in enumerate(data_loader): for i, data in enumerate(data_loader):
data = to_device(data, device) data = to_device(data, device)
data_list.append(data) data_list.append(data)
with torch.no_grad():
if isinstance(data, Mapping) and not receive_dict_inputs:
result = model.forward(**data)
else:
result = model.forward(data)
result = trainer.evaluation_step(data)
results.append(result) results.append(result)


if isinstance(data, dict): if isinstance(data, dict):


+ 2
- 2
tests/trainers/test_finetune_text_generation.py View File

@@ -59,7 +59,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
work_dir=self.tmp_dir) work_dir=self.tmp_dir)


trainer = build_trainer( trainer = build_trainer(
name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.text_generation_trainer, default_args=kwargs)
trainer.train() trainer.train()
results_files = os.listdir(self.tmp_dir) results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files) self.assertIn(f'{trainer.timestamp}.log.json', results_files)
@@ -98,7 +98,7 @@ class TestFinetuneTextGeneration(unittest.TestCase):
work_dir=self.tmp_dir) work_dir=self.tmp_dir)


trainer = build_trainer( trainer = build_trainer(
name=Trainers.nlp_base_trainer, default_args=kwargs)
name=Trainers.text_generation_trainer, default_args=kwargs)
trainer.train() trainer.train()
results_files = os.listdir(self.tmp_dir) results_files = os.listdir(self.tmp_dir)
self.assertIn(f'{trainer.timestamp}.log.json', results_files) self.assertIn(f'{trainer.timestamp}.log.json', results_files)


+ 10
- 1
tests/trainers/utils/test_inference.py View File

@@ -12,6 +12,7 @@ from modelscope.metrics.builder import MetricKeys
from modelscope.metrics.sequence_classification_metric import \ from modelscope.metrics.sequence_classification_metric import \
SequenceClassificationMetric SequenceClassificationMetric
from modelscope.models.base import Model from modelscope.models.base import Model
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test
from modelscope.utils.test_utils import (DistributedTestCase, from modelscope.utils.test_utils import (DistributedTestCase,
create_dummy_test_dataset, test_level) create_dummy_test_dataset, test_level)
@@ -36,6 +37,12 @@ class DummyModel(nn.Module, Model):
return dict(logits=x, loss=loss) return dict(logits=x, loss=loss)




class DummyTrainer(EpochBasedTrainer):

def __init__(self, model):
self.model = model


def test_func(dist=False): def test_func(dist=False):
dummy_model = DummyModel() dummy_model = DummyModel()
dataset = dummy_dataset.to_torch_dataset() dataset = dummy_dataset.to_torch_dataset()
@@ -62,8 +69,10 @@ def test_func(dist=False):
else: else:
test_func = single_gpu_test test_func = single_gpu_test


dummy_trainer = DummyTrainer(dummy_model)

metric_results = test_func( metric_results = test_func(
dummy_model,
dummy_trainer,
dummy_loader, dummy_loader,
device=device, device=device,
metric_classes=[metric_class]) metric_classes=[metric_class])


Loading…
Cancel
Save