| @@ -1,2 +1,2 @@ | |||
| from .file import File | |||
| from .file import File, LocalStorage | |||
| from .io import dump, dumps, load | |||
| @@ -240,7 +240,7 @@ class File(object): | |||
| @staticmethod | |||
| def _get_storage(uri): | |||
| assert isinstance(uri, | |||
| str), f'uri should be str type, buf got {type(uri)}' | |||
| str), f'uri should be str type, but got {type(uri)}' | |||
| if '://' not in uri: | |||
| # local path | |||
| @@ -1,13 +1,12 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from typing import Dict, Optional, Union | |||
| import numpy as np | |||
| from typing import Callable, Dict, List, Optional, Union | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.builder import build_model | |||
| from modelscope.utils.checkpoint import save_pretrained | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
| from modelscope.utils.device import device_placement, verify_device | |||
| @@ -119,3 +118,28 @@ class Model(ABC): | |||
| if hasattr(cfg, 'pipeline'): | |||
| model.pipeline = cfg.pipeline | |||
| return model | |||
| def save_pretrained(self, | |||
| target_folder: Union[str, os.PathLike], | |||
| save_checkpoint_names: Union[str, List[str]] = None, | |||
| save_function: Callable = None, | |||
| config: Optional[dict] = None, | |||
| **kwargs): | |||
| """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded | |||
| Args: | |||
| target_folder (Union[str, os.PathLike]): | |||
| Directory to which to save. Will be created if it doesn't exist. | |||
| save_checkpoint_names (Union[str, List[str]]): | |||
| The checkpoint names to be saved in the target_folder | |||
| save_function (Callable, optional): | |||
| The function to use to save the state dictionary. | |||
| config (Optional[dict], optional): | |||
| The config to be saved as configuration.json; it may not be identical to model.config. | |||
| """ | |||
| save_pretrained(self, target_folder, save_checkpoint_names, | |||
| save_function, config, **kwargs) | |||
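For reference, a minimal usage sketch of the new method. The model id, output path and the way the config dict is obtained are placeholders, and the method itself simply forwards to the save_pretrained utility in modelscope.utils.checkpoint:

from modelscope.models import Model
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.constant import ModelFile
from modelscope.utils.hub import read_config

# Placeholder model id and output folder, for illustration only.
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
model = Model.from_pretrained(model_id)

# Persist the checkpoint, configuration.json and the remaining model files.
model.save_pretrained(
    './trained_model_output',
    save_checkpoint_names=ModelFile.TORCH_MODEL_BIN_FILE,
    save_function=save_checkpoint,
    config=read_config(model_id).to_dict())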
| @@ -1,10 +1,12 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import json | |||
| from modelscope import __version__ | |||
| from modelscope.metainfo import Hooks | |||
| from modelscope.utils.checkpoint import save_checkpoint | |||
| from modelscope.utils.constant import LogKeys | |||
| from modelscope.utils.constant import LogKeys, ModelFile | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.torch_utils import is_master | |||
| from .builder import HOOKS | |||
| @@ -73,6 +75,18 @@ class CheckpointHook(Hook): | |||
| self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||
| self._save_pretrained(trainer) | |||
| def _save_pretrained(self, trainer): | |||
| if self.is_last_epoch(trainer) and self.by_epoch: | |||
| output_dir = os.path.join(self.save_dir, | |||
| ModelFile.TRAIN_OUTPUT_DIR) | |||
| trainer.model.save_pretrained( | |||
| output_dir, | |||
| ModelFile.TORCH_MODEL_BIN_FILE, | |||
| save_function=save_checkpoint, | |||
| config=trainer.cfg.to_dict()) | |||
| def after_train_iter(self, trainer): | |||
| if self.by_epoch: | |||
| @@ -166,3 +180,4 @@ class BestCkptSaverHook(CheckpointHook): | |||
| ) | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||
| self._best_ckpt_file = cur_save_name | |||
| self._save_pretrained(trainer) | |||
| @@ -1,15 +1,23 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import io | |||
| import os | |||
| import time | |||
| from collections import OrderedDict | |||
| from typing import Optional | |||
| from shutil import copytree, ignore_patterns, rmtree | |||
| from typing import Callable, List, Optional, Union | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from torch.optim import Optimizer | |||
| from modelscope import __version__ | |||
| from modelscope.fileio import File | |||
| from modelscope.fileio import File, LocalStorage | |||
| from modelscope.utils.config import JSONIteratorEncoder | |||
| from modelscope.utils.constant import ConfigFields, ModelFile | |||
| storage = LocalStorage() | |||
| def weights_to_cpu(state_dict): | |||
| @@ -72,3 +80,76 @@ def save_checkpoint(model: torch.nn.Module, | |||
| with io.BytesIO() as f: | |||
| torch.save(checkpoint, f) | |||
| File.write(f.getvalue(), filename) | |||
| def save_pretrained(model, | |||
| target_folder: Union[str, os.PathLike], | |||
| save_checkpoint_name: str = None, | |||
| save_function: Callable = None, | |||
| config: Optional[dict] = None, | |||
| **kwargs): | |||
| """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded | |||
| Args: | |||
| model (Model): Model whose params are to be saved. | |||
| target_folder (Union[str, os.PathLike]): | |||
| Directory to which to save. Will be created if it doesn't exist. | |||
| save_checkpoint_name (str): | |||
| The checkpoint name to be saved in the target_folder | |||
| save_function (Callable, optional): | |||
| The function to use to save the state dictionary. | |||
| config (Optional[dict], optional): | |||
| The config to be saved as configuration.json; it may not be identical to model.config. | |||
| """ | |||
| if save_function is None or not isinstance(save_function, Callable): | |||
| raise Exception('A valid save function must be passed in') | |||
| if target_folder is None or os.path.isfile(target_folder): | |||
| raise ValueError( | |||
| f'Provided path ({target_folder}) should be a directory, not a file' | |||
| ) | |||
| if save_checkpoint_name is None: | |||
| raise Exception( | |||
| 'A checkpoint name must be provided for saving') | |||
| if config is None: | |||
| raise ValueError('A configuration dict must be provided for saving configuration.json') | |||
| # Clean the folder from a previous save | |||
| if os.path.exists(target_folder): | |||
| rmtree(target_folder) | |||
| # Single ckpt path, sharded ckpt logic will be added later | |||
| output_ckpt_path = os.path.join(target_folder, save_checkpoint_name) | |||
| # Copy the original model files to the save directory, ignoring the original ckpts and configuration | |||
| origin_file_to_be_ignored = [save_checkpoint_name] | |||
| ignore_file_set = set(origin_file_to_be_ignored) | |||
| ignore_file_set.add(ModelFile.CONFIGURATION) | |||
| ignore_file_set.add('.*') | |||
| if hasattr(model, 'model_dir') and model.model_dir is not None: | |||
| copytree( | |||
| model.model_dir, | |||
| target_folder, | |||
| ignore=ignore_patterns(*ignore_file_set)) | |||
| # Save the ckpt to the save directory | |||
| try: | |||
| save_function(model, output_ckpt_path) | |||
| except Exception as e: | |||
| raise Exception( | |||
| f'{type(e).__name__} was raised while saving the checkpoint, ' | |||
| f'with message: {e}') | |||
| # Dump the config to the configuration.json | |||
| if ConfigFields.pipeline not in config: | |||
| config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]} | |||
| cfg_str = json.dumps(config, cls=JSONIteratorEncoder) | |||
| config_file = os.path.join(target_folder, ModelFile.CONFIGURATION) | |||
| storage.write(cfg_str.encode(), config_file) | |||
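As a side note, the ignore set assembled above is consumed by shutil's ignore_patterns, which keeps the original checkpoint, the original configuration.json and hidden files out of the copied tree. A small standalone sketch with hypothetical paths and example file names:

from shutil import copytree, ignore_patterns

# Hypothetical source and destination directories.
src_dir = '/path/to/original/model_dir'
dst_dir = '/path/to/target_folder'

# Copy everything except the listed patterns; '.*' skips hidden files
# such as .git, mirroring what save_pretrained does above.
copytree(
    src_dir,
    dst_dir,
    ignore=ignore_patterns('pytorch_model.bin', 'configuration.json', '.*'))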
| @@ -12,6 +12,7 @@ from pathlib import Path | |||
| from typing import Dict, Union | |||
| import addict | |||
| import json | |||
| from yapf.yapflib.yapf_api import FormatCode | |||
| from modelscope.utils.constant import ConfigFields, ModelFile | |||
| @@ -627,3 +628,20 @@ def check_config(cfg: Union[str, ConfigDict]): | |||
| check_attr(ConfigFields.model) | |||
| check_attr(ConfigFields.preprocessor) | |||
| check_attr(ConfigFields.evaluation) | |||
| class JSONIteratorEncoder(json.JSONEncoder): | |||
| """Implement this method in order that supporting arbitrary iterators, it returns | |||
| a serializable object for ``obj``, or calls the base implementation | |||
| (to raise a ``TypeError``). | |||
| """ | |||
| def default(self, obj): | |||
| try: | |||
| iterable = iter(obj) | |||
| except TypeError: | |||
| pass | |||
| else: | |||
| return list(iterable) | |||
| return json.JSONEncoder.default(self, obj) | |||
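A quick illustration of what the encoder adds over the stock json.JSONEncoder: any iterable that the default encoder rejects is serialized as a JSON list (the sample dict below is made up):

import json

from modelscope.utils.config import JSONIteratorEncoder

# range objects and generators are not JSON serializable by default,
# but JSONIteratorEncoder converts any iterable into a list.
sample = {'ids': range(3), 'labels': (str(i) for i in range(2))}
print(json.dumps(sample, cls=JSONIteratorEncoder))
# -> {"ids": [0, 1, 2], "labels": ["0", "1"]}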
| @@ -211,6 +211,7 @@ class ModelFile(object): | |||
| VOCAB_FILE = 'vocab.txt' | |||
| ONNX_MODEL_FILE = 'model.onnx' | |||
| LABEL_MAPPING = 'label_mapping.json' | |||
| TRAIN_OUTPUT_DIR = 'output' | |||
| class ConfigFields(object): | |||
| @@ -10,7 +10,8 @@ from modelscope.hub.constants import Licenses, ModelVisibility | |||
| from modelscope.hub.file_download import model_file_download | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, | |||
| ModelFile) | |||
| from .logger import get_logger | |||
| logger = get_logger(__name__) | |||
| @@ -119,8 +120,13 @@ def parse_label_mapping(model_dir): | |||
| if label2id is None: | |||
| config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) | |||
| config = Config.from_file(config_path) | |||
| if hasattr(config, 'model') and hasattr(config.model, 'label2id'): | |||
| label2id = config.model.label2id | |||
| if hasattr(config, ConfigFields.model) and hasattr( | |||
| config[ConfigFields.model], 'label2id'): | |||
| label2id = config[ConfigFields.model].label2id | |||
| elif hasattr(config, ConfigFields.preprocessor) and hasattr( | |||
| config[ConfigFields.preprocessor], 'label2id'): | |||
| label2id = config[ConfigFields.preprocessor].label2id | |||
| if label2id is None: | |||
| config_path = os.path.join(model_dir, 'config.json') | |||
| config = Config.from_file(config_path) | |||
| @@ -11,6 +11,7 @@ import torch | |||
| from torch import nn | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import LogKeys, ModelFile | |||
| from modelscope.utils.test_utils import create_dummy_test_dataset | |||
| @@ -19,7 +20,7 @@ dummy_dataset = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -11,11 +11,14 @@ from torch import nn | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.metrics.builder import METRICS, MetricKeys | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import LogKeys, ModelFile | |||
| from modelscope.utils.registry import default_group | |||
| from modelscope.utils.test_utils import create_dummy_test_dataset | |||
| SRC_DIR = os.path.dirname(__file__) | |||
| def create_dummy_metric(): | |||
| _global_iter = 0 | |||
| @@ -39,12 +42,13 @@ dummy_dataset = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.linear = nn.Linear(5, 4) | |||
| self.bn = nn.BatchNorm1d(4) | |||
| self.model_dir = SRC_DIR | |||
| def forward(self, feat, labels): | |||
| x = self.linear(feat) | |||
| @@ -123,6 +127,14 @@ class CheckpointHookTest(unittest.TestCase): | |||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||
| output_files = os.listdir( | |||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||
| copy_src_files = os.listdir(SRC_DIR) | |||
| self.assertIn(copy_src_files[0], output_files) | |||
| self.assertIn(copy_src_files[-1], output_files) | |||
| class BestCkptSaverHookTest(unittest.TestCase): | |||
| @@ -198,6 +210,14 @@ class BestCkptSaverHookTest(unittest.TestCase): | |||
| self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth', | |||
| results_files) | |||
| output_files = os.listdir( | |||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||
| copy_src_files = os.listdir(SRC_DIR) | |||
| self.assertIn(copy_src_files[0], output_files) | |||
| self.assertIn(copy_src_files[-1], output_files) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -11,6 +11,7 @@ from torch import nn | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.metrics.builder import METRICS, MetricKeys | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.registry import default_group | |||
| @@ -34,7 +35,7 @@ dummy_dataset = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -13,6 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.metrics.builder import METRICS, MetricKeys | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | |||
| from modelscope.utils.registry import default_group | |||
| @@ -40,7 +41,7 @@ def create_dummy_metric(): | |||
| return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -12,6 +12,7 @@ from torch.optim import SGD | |||
| from torch.optim.lr_scheduler import MultiStepLR | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile, TrainerStages | |||
| from modelscope.utils.test_utils import create_dummy_test_dataset | |||
| @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( | |||
| np.random.random(size=(2, )), np.random.randint(0, 2, (1, )), 10) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -12,6 +12,7 @@ from torch.optim import SGD | |||
| from torch.optim.lr_scheduler import MultiStepLR | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | |||
| from modelscope.utils.test_utils import create_dummy_test_dataset | |||
| @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -83,8 +84,8 @@ class IterTimerHookTest(unittest.TestCase): | |||
| trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | |||
| trainer.register_optimizers_hook() | |||
| trainer.register_hook_from_cfg(trainer.cfg.train.hooks) | |||
| trainer.data_loader = train_dataloader | |||
| trainer.train_dataloader = train_dataloader | |||
| trainer.data_loader = train_dataloader | |||
| trainer.invoke_hook(TrainerStages.before_run) | |||
| for i in range(trainer._epoch, trainer._max_epochs): | |||
| trainer.invoke_hook(TrainerStages.before_train_epoch) | |||
| @@ -4,11 +4,18 @@ import shutil | |||
| import tempfile | |||
| import unittest | |||
| from modelscope.metainfo import Trainers | |||
| from modelscope.metainfo import Preprocessors, Trainers | |||
| from modelscope.models import Model | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| epoch_num = 1 | |||
| sentence1 = '今天气温比昨天高么?' | |||
| sentence2 = '今天湿度比昨天高么?' | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| @@ -40,15 +47,32 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| trainer.train() | |||
| results_files = os.listdir(self.tmp_dir) | |||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | |||
| for i in range(10): | |||
| for i in range(self.epoch_num): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| output_files = os.listdir( | |||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||
| copy_src_files = os.listdir(trainer.model_dir) | |||
| print(f'copy_src_files are {copy_src_files}') | |||
| print(f'output_files are {output_files}') | |||
| for item in copy_src_files: | |||
| if not item.startswith('.'): | |||
| self.assertIn(item, output_files) | |||
| def pipeline_sentence_similarity(self, model_dir): | |||
| model = Model.from_pretrained(model_dir) | |||
| pipeline_ins = pipeline(task=Tasks.sentence_similarity, model=model) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| @unittest.skip | |||
| def test_finetune_afqmc(self): | |||
| def cfg_modify_fn(cfg): | |||
| cfg.task = 'sentence-similarity' | |||
| cfg['preprocessor'] = {'type': 'sen-sim-tokenizer'} | |||
| cfg.task = Tasks.sentence_similarity | |||
| cfg['preprocessor'] = {'type': Preprocessors.sen_sim_tokenizer} | |||
| cfg.train.optimizer.lr = 2e-5 | |||
| cfg['dataset'] = { | |||
| 'train': { | |||
| @@ -58,7 +82,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| 'label': 'label', | |||
| } | |||
| } | |||
| cfg.train.max_epochs = 10 | |||
| cfg.train.max_epochs = self.epoch_num | |||
| cfg.train.lr_scheduler = { | |||
| 'type': 'LinearLR', | |||
| 'start_factor': 1.0, | |||
| @@ -95,6 +119,9 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||
| eval_dataset=dataset['validation'], | |||
| cfg_modify_fn=cfg_modify_fn) | |||
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
| self.pipeline_sentence_similarity(output_dir) | |||
| @unittest.skip | |||
| def test_finetune_tnews(self): | |||
| @@ -14,6 +14,7 @@ from torch.utils.data import IterableDataset | |||
| from modelscope.metainfo import Metrics, Trainers | |||
| from modelscope.metrics.builder import MetricKeys | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | |||
| from modelscope.utils.test_utils import create_dummy_test_dataset, test_level | |||
| @@ -35,7 +36,7 @@ dummy_dataset_big = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -15,6 +15,7 @@ from torch.utils.data import IterableDataset | |||
| from modelscope.metainfo import Metrics, Trainers | |||
| from modelscope.metrics.builder import MetricKeys | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers import EpochBasedTrainer, build_trainer | |||
| from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | |||
| from modelscope.utils.test_utils import (DistributedTestCase, | |||
| @@ -37,7 +38,7 @@ dummy_dataset_big = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||
| @@ -6,16 +6,20 @@ import unittest | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Metrics | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.nlp.sequence_classification import \ | |||
| SbertForSequenceClassification | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import ModelFile | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.hub import read_config | |||
| from modelscope.utils.test_utils import test_level | |||
| class TestTrainerWithNlp(unittest.TestCase): | |||
| sentence1 = '今天气温比昨天高么?' | |||
| sentence2 = '今天湿度比昨天高么?' | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| @@ -30,7 +34,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| shutil.rmtree(self.tmp_dir) | |||
| super().tearDown() | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_trainer(self): | |||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | |||
| kwargs = dict( | |||
| @@ -47,6 +51,27 @@ class TestTrainerWithNlp(unittest.TestCase): | |||
| for i in range(10): | |||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | |||
| output_files = os.listdir( | |||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||
| copy_src_files = os.listdir(trainer.model_dir) | |||
| print(f'copy_src_files are {copy_src_files}') | |||
| print(f'output_files are {output_files}') | |||
| for item in copy_src_files: | |||
| if not item.startswith('.'): | |||
| self.assertIn(item, output_files) | |||
| def pipeline_sentence_similarity(model_dir): | |||
| model = Model.from_pretrained(model_dir) | |||
| pipeline_ins = pipeline( | |||
| task=Tasks.sentence_similarity, model=model) | |||
| print(pipeline_ins(input=(self.sentence1, self.sentence2))) | |||
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||
| pipeline_sentence_similarity(output_dir) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_trainer_with_backbone_head(self): | |||
| model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | |||
| @@ -11,6 +11,7 @@ from torch.utils.data import DataLoader | |||
| from modelscope.metrics.builder import MetricKeys | |||
| from modelscope.metrics.sequence_classification_metric import \ | |||
| SequenceClassificationMetric | |||
| from modelscope.models.base import Model | |||
| from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test | |||
| from modelscope.utils.test_utils import (DistributedTestCase, | |||
| create_dummy_test_dataset, test_level) | |||
| @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( | |||
| torch.rand((5, )), torch.randint(0, 4, (1, )), 20) | |||
| class DummyModel(nn.Module): | |||
| class DummyModel(nn.Module, Model): | |||
| def __init__(self): | |||
| super().__init__() | |||