diff --git a/modelscope/fileio/__init__.py b/modelscope/fileio/__init__.py index 5fd10f85..b526d593 100644 --- a/modelscope/fileio/__init__.py +++ b/modelscope/fileio/__init__.py @@ -1,2 +1,2 @@ -from .file import File +from .file import File, LocalStorage from .io import dump, dumps, load diff --git a/modelscope/fileio/file.py b/modelscope/fileio/file.py index 343cad9a..3fff80c8 100644 --- a/modelscope/fileio/file.py +++ b/modelscope/fileio/file.py @@ -240,7 +240,7 @@ class File(object): @staticmethod def _get_storage(uri): assert isinstance(uri, - str), f'uri should be str type, buf got {type(uri)}' + str), f'uri should be str type, but got {type(uri)}' if '://' not in uri: # local path diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 279dbba2..872c42e8 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -1,13 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - +import os import os.path as osp from abc import ABC, abstractmethod -from typing import Dict, Optional, Union - -import numpy as np +from typing import Callable, Dict, List, Optional, Union from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.builder import build_model +from modelscope.utils.checkpoint import save_pretrained from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile from modelscope.utils.device import device_placement, verify_device @@ -119,3 +118,28 @@ class Model(ABC): if hasattr(cfg, 'pipeline'): model.pipeline = cfg.pipeline return model + + def save_pretrained(self, + target_folder: Union[str, os.PathLike], + save_checkpoint_names: Union[str, List[str]] = None, + save_function: Callable = None, + config: Optional[dict] = None, + **kwargs): + """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded + + Args: + target_folder (Union[str, os.PathLike]): + Directory to which to save. Will be created if it doesn't exist. + + save_checkpoint_names (Union[str, List[str]]): + The checkpoint names to be saved in the target_folder + + save_function (Callable, optional): + The function to use to save the state dictionary. + + config (Optional[dict], optional): + The config for the configuration.json, might not be identical with model.config + + """ + save_pretrained(self, target_folder, save_checkpoint_names, + save_function, config, **kwargs) diff --git a/modelscope/trainers/hooks/checkpoint_hook.py b/modelscope/trainers/hooks/checkpoint_hook.py index fc0281a1..623d4654 100644 --- a/modelscope/trainers/hooks/checkpoint_hook.py +++ b/modelscope/trainers/hooks/checkpoint_hook.py @@ -1,10 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import json + from modelscope import __version__ from modelscope.metainfo import Hooks from modelscope.utils.checkpoint import save_checkpoint -from modelscope.utils.constant import LogKeys +from modelscope.utils.constant import LogKeys, ModelFile from modelscope.utils.logger import get_logger from modelscope.utils.torch_utils import is_master from .builder import HOOKS @@ -73,6 +75,18 @@ class CheckpointHook(Hook): self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) + self._save_pretrained(trainer) + + def _save_pretrained(self, trainer): + if self.is_last_epoch(trainer) and self.by_epoch: + output_dir = os.path.join(self.save_dir, + ModelFile.TRAIN_OUTPUT_DIR) + + trainer.model.save_pretrained( + output_dir, + ModelFile.TORCH_MODEL_BIN_FILE, + save_function=save_checkpoint, + config=trainer.cfg.to_dict()) def after_train_iter(self, trainer): if self.by_epoch: @@ -166,3 +180,4 @@ class BestCkptSaverHook(CheckpointHook): ) save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) self._best_ckpt_file = cur_save_name + self._save_pretrained(trainer) diff --git a/modelscope/utils/checkpoint.py b/modelscope/utils/checkpoint.py index 76fb2a19..8b9d027a 100644 --- a/modelscope/utils/checkpoint.py +++ b/modelscope/utils/checkpoint.py @@ -1,15 +1,23 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import io +import os import time from collections import OrderedDict -from typing import Optional +from shutil import copytree, ignore_patterns, rmtree +from typing import Callable, List, Optional, Union +import json +import numpy as np import torch from torch.optim import Optimizer from modelscope import __version__ -from modelscope.fileio import File +from modelscope.fileio import File, LocalStorage +from modelscope.utils.config import JSONIteratorEncoder +from modelscope.utils.constant import ConfigFields, ModelFile + +storage = LocalStorage() def weights_to_cpu(state_dict): @@ -72,3 +80,76 @@ def save_checkpoint(model: torch.nn.Module, with io.BytesIO() as f: torch.save(checkpoint, f) File.write(f.getvalue(), filename) + + +def save_pretrained(model, + target_folder: Union[str, os.PathLike], + save_checkpoint_name: str = None, + save_function: Callable = None, + config: Optional[dict] = None, + **kwargs): + """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded + + Args: + model (Model): Model whose params are to be saved. + + target_folder (Union[str, os.PathLike]): + Directory to which to save. Will be created if it doesn't exist. + + save_checkpoint_name (str): + The checkpoint name to be saved in the target_folder + + save_function (Callable, optional): + The function to use to save the state dictionary. + + config (Optional[dict], optional): + The config for the configuration.json, might not be identical with model.config + """ + + if save_function is None or not isinstance(save_function, Callable): + raise Exception('A valid save function must be passed in') + + if target_folder is None or os.path.isfile(target_folder): + raise ValueError( + f'Provided path ({target_folder}) should be a directory, not a file' + ) + + if save_checkpoint_name is None: + raise Exception( + 'At least pass in one checkpoint name for saving method') + + if config is None: + raise ValueError('Configuration is not valid') + + # Clean the folder from a previous save + if os.path.exists(target_folder): + rmtree(target_folder) + + # Single ckpt path, sharded ckpt logic will be added later + output_ckpt_path = os.path.join(target_folder, save_checkpoint_name) + + # Save the files to be copied to the save directory, ignore the original ckpts and configuration + origin_file_to_be_ignored = [save_checkpoint_name] + ignore_file_set = set(origin_file_to_be_ignored) + ignore_file_set.add(ModelFile.CONFIGURATION) + ignore_file_set.add('.*') + if hasattr(model, 'model_dir') and model.model_dir is not None: + copytree( + model.model_dir, + target_folder, + ignore=ignore_patterns(*ignore_file_set)) + + # Save the ckpt to the save directory + try: + save_function(model, output_ckpt_path) + except Exception as e: + raise Exception( + f'During saving checkpoints, the error of "{type(e).__name__} ' + f'with msg {e} throwed') + + # Dump the config to the configuration.json + if ConfigFields.pipeline not in config: + config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]} + cfg_str = json.dumps(config, cls=JSONIteratorEncoder) + config_file = os.path.join(target_folder, ModelFile.CONFIGURATION) + storage.write(cfg_str.encode(), config_file) diff --git a/modelscope/utils/config.py b/modelscope/utils/config.py index a28ac1ab..42985db6 100644 --- a/modelscope/utils/config.py +++ b/modelscope/utils/config.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Dict, Union import addict +import json from yapf.yapflib.yapf_api import FormatCode from modelscope.utils.constant import ConfigFields, ModelFile @@ -627,3 +628,20 @@ def check_config(cfg: Union[str, ConfigDict]): check_attr(ConfigFields.model) check_attr(ConfigFields.preprocessor) check_attr(ConfigFields.evaluation) + + +class JSONIteratorEncoder(json.JSONEncoder): + """Implement this method in order that supporting arbitrary iterators, it returns + a serializable object for ``obj``, or calls the base implementation + (to raise a ``TypeError``). + + """ + + def default(self, obj): + try: + iterable = iter(obj) + except TypeError: + pass + else: + return list(iterable) + return json.JSONEncoder.default(self, obj) diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index d914767b..81712983 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -211,6 +211,7 @@ class ModelFile(object): VOCAB_FILE = 'vocab.txt' ONNX_MODEL_FILE = 'model.onnx' LABEL_MAPPING = 'label_mapping.json' + TRAIN_OUTPUT_DIR = 'output' class ConfigFields(object): diff --git a/modelscope/utils/hub.py b/modelscope/utils/hub.py index 6d685b87..f79097fe 100644 --- a/modelscope/utils/hub.py +++ b/modelscope/utils/hub.py @@ -10,7 +10,8 @@ from modelscope.hub.constants import Licenses, ModelVisibility from modelscope.hub.file_download import model_file_download from modelscope.hub.snapshot_download import snapshot_download from modelscope.utils.config import Config -from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, + ModelFile) from .logger import get_logger logger = get_logger(__name__) @@ -119,8 +120,13 @@ def parse_label_mapping(model_dir): if label2id is None: config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) config = Config.from_file(config_path) - if hasattr(config, 'model') and hasattr(config.model, 'label2id'): - label2id = config.model.label2id + if hasattr(config, ConfigFields.model) and hasattr( + config[ConfigFields.model], 'label2id'): + label2id = config[ConfigFields.model].label2id + elif hasattr(config, ConfigFields.preprocessor) and hasattr( + config[ConfigFields.preprocessor], 'label2id'): + label2id = config[ConfigFields.preprocessor].label2id + if label2id is None: config_path = os.path.join(model_dir, 'config.json') config = Config.from_file(config_path) diff --git a/tests/trainers/hooks/logger/test_tensorboard_hook.py b/tests/trainers/hooks/logger/test_tensorboard_hook.py index 54c31056..67b1aa63 100644 --- a/tests/trainers/hooks/logger/test_tensorboard_hook.py +++ b/tests/trainers/hooks/logger/test_tensorboard_hook.py @@ -11,6 +11,7 @@ import torch from torch import nn from modelscope.metainfo import Trainers +from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile from modelscope.utils.test_utils import create_dummy_test_dataset @@ -19,7 +20,7 @@ dummy_dataset = create_dummy_test_dataset( np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() diff --git a/tests/trainers/hooks/test_checkpoint_hook.py b/tests/trainers/hooks/test_checkpoint_hook.py index 1c81d057..c694ece6 100644 --- a/tests/trainers/hooks/test_checkpoint_hook.py +++ b/tests/trainers/hooks/test_checkpoint_hook.py @@ -11,11 +11,14 @@ from torch import nn from modelscope.metainfo import Trainers from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile from modelscope.utils.registry import default_group from modelscope.utils.test_utils import create_dummy_test_dataset +SRC_DIR = os.path.dirname(__file__) + def create_dummy_metric(): _global_iter = 0 @@ -39,12 +42,13 @@ dummy_dataset = create_dummy_test_dataset( np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() self.linear = nn.Linear(5, 4) self.bn = nn.BatchNorm1d(4) + self.model_dir = SRC_DIR def forward(self, feat, labels): x = self.linear(feat) @@ -123,6 +127,14 @@ class CheckpointHookTest(unittest.TestCase): self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(SRC_DIR) + self.assertIn(copy_src_files[0], output_files) + self.assertIn(copy_src_files[-1], output_files) + class BestCkptSaverHookTest(unittest.TestCase): @@ -198,6 +210,14 @@ class BestCkptSaverHookTest(unittest.TestCase): self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth', results_files) + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(SRC_DIR) + self.assertIn(copy_src_files[0], output_files) + self.assertIn(copy_src_files[-1], output_files) + if __name__ == '__main__': unittest.main() diff --git a/tests/trainers/hooks/test_evaluation_hook.py b/tests/trainers/hooks/test_evaluation_hook.py index 1338bb2c..2c71e790 100644 --- a/tests/trainers/hooks/test_evaluation_hook.py +++ b/tests/trainers/hooks/test_evaluation_hook.py @@ -11,6 +11,7 @@ from torch import nn from modelscope.metainfo import Trainers from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.utils.constant import ModelFile from modelscope.utils.registry import default_group @@ -34,7 +35,7 @@ dummy_dataset = create_dummy_test_dataset( np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() diff --git a/tests/trainers/hooks/test_lr_scheduler_hook.py b/tests/trainers/hooks/test_lr_scheduler_hook.py index 86d53ecc..7a1ff220 100644 --- a/tests/trainers/hooks/test_lr_scheduler_hook.py +++ b/tests/trainers/hooks/test_lr_scheduler_hook.py @@ -13,6 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR from modelscope.metainfo import Trainers from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages from modelscope.utils.registry import default_group @@ -40,7 +41,7 @@ def create_dummy_metric(): return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() diff --git a/tests/trainers/hooks/test_optimizer_hook.py b/tests/trainers/hooks/test_optimizer_hook.py index 25457c1c..84c783b5 100644 --- a/tests/trainers/hooks/test_optimizer_hook.py +++ b/tests/trainers/hooks/test_optimizer_hook.py @@ -12,6 +12,7 @@ from torch.optim import SGD from torch.optim.lr_scheduler import MultiStepLR from modelscope.metainfo import Trainers +from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.utils.constant import ModelFile, TrainerStages from modelscope.utils.test_utils import create_dummy_test_dataset @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( np.random.random(size=(2, )), np.random.randint(0, 2, (1, )), 10) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() diff --git a/tests/trainers/hooks/test_timer_hook.py b/tests/trainers/hooks/test_timer_hook.py index 614f7688..9fb79c77 100644 --- a/tests/trainers/hooks/test_timer_hook.py +++ b/tests/trainers/hooks/test_timer_hook.py @@ -12,6 +12,7 @@ from torch.optim import SGD from torch.optim.lr_scheduler import MultiStepLR from modelscope.metainfo import Trainers +from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages from modelscope.utils.test_utils import create_dummy_test_dataset @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() @@ -83,8 +84,8 @@ class IterTimerHookTest(unittest.TestCase): trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) trainer.register_optimizers_hook() trainer.register_hook_from_cfg(trainer.cfg.train.hooks) - trainer.data_loader = train_dataloader trainer.train_dataloader = train_dataloader + trainer.data_loader = train_dataloader trainer.invoke_hook(TrainerStages.before_run) for i in range(trainer._epoch, trainer._max_epochs): trainer.invoke_hook(TrainerStages.before_train_epoch) diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index 12c7da77..847e47ef 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -4,11 +4,18 @@ import shutil import tempfile import unittest -from modelscope.metainfo import Trainers +from modelscope.metainfo import Preprocessors, Trainers +from modelscope.models import Model +from modelscope.pipelines import pipeline from modelscope.trainers import build_trainer +from modelscope.utils.constant import ModelFile, Tasks class TestFinetuneSequenceClassification(unittest.TestCase): + epoch_num = 1 + + sentence1 = '今天气温比昨天高么?' + sentence2 = '今天湿度比昨天高么?' def setUp(self): print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) @@ -40,15 +47,32 @@ class TestFinetuneSequenceClassification(unittest.TestCase): trainer.train() results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) - for i in range(10): + for i in range(self.epoch_num): self.assertIn(f'epoch_{i+1}.pth', results_files) + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(trainer.model_dir) + + print(f'copy_src_files are {copy_src_files}') + print(f'output_files are {output_files}') + for item in copy_src_files: + if not item.startswith('.'): + self.assertIn(item, output_files) + + def pipeline_sentence_similarity(self, model_dir): + model = Model.from_pretrained(model_dir) + pipeline_ins = pipeline(task=Tasks.sentence_similarity, model=model) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + @unittest.skip def test_finetune_afqmc(self): def cfg_modify_fn(cfg): - cfg.task = 'sentence-similarity' - cfg['preprocessor'] = {'type': 'sen-sim-tokenizer'} + cfg.task = Tasks.sentence_similarity + cfg['preprocessor'] = {'type': Preprocessors.sen_sim_tokenizer} cfg.train.optimizer.lr = 2e-5 cfg['dataset'] = { 'train': { @@ -58,7 +82,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): 'label': 'label', } } - cfg.train.max_epochs = 10 + cfg.train.max_epochs = self.epoch_num cfg.train.lr_scheduler = { 'type': 'LinearLR', 'start_factor': 1.0, @@ -95,6 +119,9 @@ class TestFinetuneSequenceClassification(unittest.TestCase): eval_dataset=dataset['validation'], cfg_modify_fn=cfg_modify_fn) + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + self.pipeline_sentence_similarity(output_dir) + @unittest.skip def test_finetune_tnews(self): diff --git a/tests/trainers/test_trainer.py b/tests/trainers/test_trainer.py index 0259f804..be29844d 100644 --- a/tests/trainers/test_trainer.py +++ b/tests/trainers/test_trainer.py @@ -14,6 +14,7 @@ from torch.utils.data import IterableDataset from modelscope.metainfo import Metrics, Trainers from modelscope.metrics.builder import MetricKeys +from modelscope.models.base import Model from modelscope.trainers import build_trainer from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile from modelscope.utils.test_utils import create_dummy_test_dataset, test_level @@ -35,7 +36,7 @@ dummy_dataset_big = create_dummy_test_dataset( np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() diff --git a/tests/trainers/test_trainer_gpu.py b/tests/trainers/test_trainer_gpu.py index 9781816d..3777772d 100644 --- a/tests/trainers/test_trainer_gpu.py +++ b/tests/trainers/test_trainer_gpu.py @@ -15,6 +15,7 @@ from torch.utils.data import IterableDataset from modelscope.metainfo import Metrics, Trainers from modelscope.metrics.builder import MetricKeys +from modelscope.models.base import Model from modelscope.trainers import EpochBasedTrainer, build_trainer from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile from modelscope.utils.test_utils import (DistributedTestCase, @@ -37,7 +38,7 @@ dummy_dataset_big = create_dummy_test_dataset( np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__() diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index 213b6b4f..2cf1c152 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -6,16 +6,20 @@ import unittest from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Metrics +from modelscope.models.base import Model from modelscope.models.nlp.sequence_classification import \ SbertForSequenceClassification from modelscope.msdatasets import MsDataset +from modelscope.pipelines import pipeline from modelscope.trainers import build_trainer -from modelscope.utils.constant import ModelFile +from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.hub import read_config from modelscope.utils.test_utils import test_level class TestTrainerWithNlp(unittest.TestCase): + sentence1 = '今天气温比昨天高么?' + sentence2 = '今天湿度比昨天高么?' def setUp(self): print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) @@ -30,7 +34,7 @@ class TestTrainerWithNlp(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_trainer(self): model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' kwargs = dict( @@ -47,6 +51,27 @@ class TestTrainerWithNlp(unittest.TestCase): for i in range(10): self.assertIn(f'epoch_{i+1}.pth', results_files) + output_files = os.listdir( + os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) + self.assertIn(ModelFile.CONFIGURATION, output_files) + self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) + copy_src_files = os.listdir(trainer.model_dir) + + print(f'copy_src_files are {copy_src_files}') + print(f'output_files are {output_files}') + for item in copy_src_files: + if not item.startswith('.'): + self.assertIn(item, output_files) + + def pipeline_sentence_similarity(model_dir): + model = Model.from_pretrained(model_dir) + pipeline_ins = pipeline( + task=Tasks.sentence_similarity, model=model) + print(pipeline_ins(input=(self.sentence1, self.sentence2))) + + output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) + pipeline_sentence_similarity(output_dir) + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_trainer_with_backbone_head(self): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' diff --git a/tests/trainers/utils/test_inference.py b/tests/trainers/utils/test_inference.py index 87e5320e..23561734 100644 --- a/tests/trainers/utils/test_inference.py +++ b/tests/trainers/utils/test_inference.py @@ -11,6 +11,7 @@ from torch.utils.data import DataLoader from modelscope.metrics.builder import MetricKeys from modelscope.metrics.sequence_classification_metric import \ SequenceClassificationMetric +from modelscope.models.base import Model from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test from modelscope.utils.test_utils import (DistributedTestCase, create_dummy_test_dataset, test_level) @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( torch.rand((5, )), torch.randint(0, 4, (1, )), 20) -class DummyModel(nn.Module): +class DummyModel(nn.Module, Model): def __init__(self): super().__init__()