| @@ -1,2 +1,2 @@ | |||||
| from .file import File | |||||
| from .file import File, LocalStorage | |||||
| from .io import dump, dumps, load | from .io import dump, dumps, load | ||||
| @@ -240,7 +240,7 @@ class File(object): | |||||
| @staticmethod | @staticmethod | ||||
| def _get_storage(uri): | def _get_storage(uri): | ||||
| assert isinstance(uri, | assert isinstance(uri, | ||||
| str), f'uri should be str type, buf got {type(uri)}' | |||||
| str), f'uri should be str type, but got {type(uri)}' | |||||
| if '://' not in uri: | if '://' not in uri: | ||||
| # local path | # local path | ||||
| @@ -1,13 +1,12 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | |||||
| import os.path as osp | import os.path as osp | ||||
| from abc import ABC, abstractmethod | from abc import ABC, abstractmethod | ||||
| from typing import Dict, Optional, Union | |||||
| import numpy as np | |||||
| from typing import Callable, Dict, List, Optional, Union | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.models.builder import build_model | from modelscope.models.builder import build_model | ||||
| from modelscope.utils.checkpoint import save_pretrained | |||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | ||||
| from modelscope.utils.device import device_placement, verify_device | from modelscope.utils.device import device_placement, verify_device | ||||
| @@ -119,3 +118,28 @@ class Model(ABC): | |||||
| if hasattr(cfg, 'pipeline'): | if hasattr(cfg, 'pipeline'): | ||||
| model.pipeline = cfg.pipeline | model.pipeline = cfg.pipeline | ||||
| return model | return model | ||||
def save_pretrained(self,
                    target_folder: Union[str, os.PathLike],
                    save_checkpoint_names: Union[str, List[str]] = None,
                    save_function: Callable = None,
                    config: Optional[dict] = None,
                    **kwargs):
    """Write this model, its configuration and related files to a folder.

    This method only forwards its arguments, with ``self`` as the model,
    to the module-level ``save_pretrained`` helper imported from
    ``modelscope.utils.checkpoint``; it returns nothing itself.

    Args:
        target_folder (Union[str, os.PathLike]):
            Directory that receives the saved files.
        save_checkpoint_names (Union[str, List[str]]):
            Checkpoint name(s) to write inside ``target_folder``.
            NOTE(review): the helper's matching parameter is annotated
            as a single ``str`` -- confirm list values are supported.
        save_function (Callable, optional):
            Function used to save the state dictionary.
        config (Optional[dict], optional):
            Content for configuration.json; it might not be identical
            to ``model.config``.
        **kwargs: Forwarded unchanged to the helper.
    """
    save_pretrained(
        self,
        target_folder,
        save_checkpoint_names,
        save_function,
        config,
        **kwargs,
    )
| @@ -1,10 +1,12 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import os | import os | ||||
| import json | |||||
| from modelscope import __version__ | from modelscope import __version__ | ||||
| from modelscope.metainfo import Hooks | from modelscope.metainfo import Hooks | ||||
| from modelscope.utils.checkpoint import save_checkpoint | from modelscope.utils.checkpoint import save_checkpoint | ||||
| from modelscope.utils.constant import LogKeys | |||||
| from modelscope.utils.constant import LogKeys, ModelFile | |||||
| from modelscope.utils.logger import get_logger | from modelscope.utils.logger import get_logger | ||||
| from modelscope.utils.torch_utils import is_master | from modelscope.utils.torch_utils import is_master | ||||
| from .builder import HOOKS | from .builder import HOOKS | ||||
| @@ -73,6 +75,18 @@ class CheckpointHook(Hook): | |||||
| self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | ||||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | ||||
| self._save_pretrained(trainer) | |||||
def _save_pretrained(self, trainer):
    """Export the trained model as a re-loadable folder after the last epoch.

    Nothing happens unless ``self.is_last_epoch(trainer)`` is truthy and
    the hook runs in ``by_epoch`` mode. Otherwise the trainer's model is
    asked to save itself into ``<save_dir>/<ModelFile.TRAIN_OUTPUT_DIR>``
    using ``save_checkpoint`` as the save function and the trainer's
    config (converted to a dict) as the configuration.

    Args:
        trainer: The trainer providing ``model`` and ``cfg``.
    """
    reached_end = self.is_last_epoch(trainer)
    if not (reached_end and self.by_epoch):
        return

    export_dir = os.path.join(self.save_dir, ModelFile.TRAIN_OUTPUT_DIR)
    config_dict = trainer.cfg.to_dict()
    trainer.model.save_pretrained(
        export_dir,
        ModelFile.TORCH_MODEL_BIN_FILE,
        save_function=save_checkpoint,
        config=config_dict)
| def after_train_iter(self, trainer): | def after_train_iter(self, trainer): | ||||
| if self.by_epoch: | if self.by_epoch: | ||||
| @@ -166,3 +180,4 @@ class BestCkptSaverHook(CheckpointHook): | |||||
| ) | ) | ||||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | ||||
| self._best_ckpt_file = cur_save_name | self._best_ckpt_file = cur_save_name | ||||
| self._save_pretrained(trainer) | |||||
| @@ -1,15 +1,23 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import io | import io | ||||
| import os | |||||
| import time | import time | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from typing import Optional | |||||
| from shutil import copytree, ignore_patterns, rmtree | |||||
| from typing import Callable, List, Optional, Union | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | import torch | ||||
| from torch.optim import Optimizer | from torch.optim import Optimizer | ||||
| from modelscope import __version__ | from modelscope import __version__ | ||||
| from modelscope.fileio import File | |||||
| from modelscope.fileio import File, LocalStorage | |||||
| from modelscope.utils.config import JSONIteratorEncoder | |||||
| from modelscope.utils.constant import ConfigFields, ModelFile | |||||
| storage = LocalStorage() | |||||
| def weights_to_cpu(state_dict): | def weights_to_cpu(state_dict): | ||||
| @@ -72,3 +80,76 @@ def save_checkpoint(model: torch.nn.Module, | |||||
| with io.BytesIO() as f: | with io.BytesIO() as f: | ||||
| torch.save(checkpoint, f) | torch.save(checkpoint, f) | ||||
| File.write(f.getvalue(), filename) | File.write(f.getvalue(), filename) | ||||
def save_pretrained(model,
                    target_folder: Union[str, os.PathLike],
                    save_checkpoint_name: str = None,
                    save_function: Callable = None,
                    config: Optional[dict] = None,
                    **kwargs):
    """Save a model, its configuration and related files to a directory.

    The target directory is wiped if it already exists, then populated with
    (1) every file from ``model.model_dir`` except the original checkpoint,
    the original configuration file and hidden entries, (2) the checkpoint
    written by ``save_function`` and (3) ``config`` dumped as JSON to the
    configuration file.

    Args:
        model (Model): Model whose params are to be saved.
        target_folder (Union[str, os.PathLike]):
            Directory to which to save. Will be created if it doesn't exist.
            Any existing content is removed first.
        save_checkpoint_name (str):
            The checkpoint name to be saved in the target_folder.
        save_function (Callable, optional):
            The function to use to save the state dictionary. It is called
            as ``save_function(model, checkpoint_path)``.
        config (Optional[dict], optional):
            The config for the configuration.json, might not be identical
            with model.config. The passed-in mapping is not modified.
        **kwargs: Accepted for interface compatibility; currently unused.

    Raises:
        ValueError: If ``target_folder`` is None or an existing file, if
            ``config`` is None, or if ``model.model_dir`` is the target
            folder itself or lies inside it.
        KeyError: If ``config`` has neither a pipeline nor a task entry.
        Exception: If ``save_function`` is not callable, if
            ``save_checkpoint_name`` is None, or if ``save_function`` fails
            (the original error is chained as ``__cause__``).
    """
    if not callable(save_function):
        raise Exception('A valid save function must be passed in')

    if target_folder is None or os.path.isfile(target_folder):
        raise ValueError(
            f'Provided path ({target_folder}) should be a directory, not a file'
        )

    if save_checkpoint_name is None:
        raise Exception(
            'At least pass in one checkpoint name for saving method')

    if config is None:
        raise ValueError('Configuration is not valid')

    # The target folder is deleted below. Refuse to continue when that would
    # also delete the directory the model files are copied from.
    model_dir = getattr(model, 'model_dir', None)
    if model_dir is not None:
        src_root = os.path.join(
            os.path.normcase(os.path.realpath(model_dir)), '')
        dst_root = os.path.join(
            os.path.normcase(os.path.realpath(target_folder)), '')
        if src_root.startswith(dst_root):
            raise ValueError(
                f'model_dir ({model_dir}) must not be the target folder '
                f'({target_folder}) or be located inside it, because the '
                f'target folder is cleaned before saving')

    # Clean the folder from a previous save
    if os.path.exists(target_folder):
        rmtree(target_folder)

    # Single ckpt path, sharded ckpt logic will be added later
    output_ckpt_path = os.path.join(target_folder, save_checkpoint_name)

    # Copy the auxiliary model files, ignoring the original ckpt, the
    # original configuration and hidden entries
    if model_dir is not None:
        ignore_file_set = {
            save_checkpoint_name, ModelFile.CONFIGURATION, '.*'
        }
        copytree(
            model_dir,
            target_folder,
            ignore=ignore_patterns(*ignore_file_set))

    # Without a model_dir nothing has created the folder yet; make sure it
    # exists so that any save function can write into it.
    os.makedirs(target_folder, exist_ok=True)

    # Save the ckpt to the save directory
    try:
        save_function(model, output_ckpt_path)
    except Exception as e:
        raise Exception(
            f'During saving checkpoints, the error of "{type(e).__name__}" '
            f'with msg {e} was thrown') from e

    # Dump the config to the configuration.json. A new dict is built instead
    # of assigning into ``config`` so the caller's mapping stays untouched.
    if ConfigFields.pipeline not in config:
        config = {
            **config, ConfigFields.pipeline: {
                'type': config[ConfigFields.task]
            }
        }
    cfg_str = json.dumps(config, cls=JSONIteratorEncoder)
    config_file = os.path.join(target_folder, ModelFile.CONFIGURATION)
    storage.write(cfg_str.encode(), config_file)
| @@ -12,6 +12,7 @@ from pathlib import Path | |||||
| from typing import Dict, Union | from typing import Dict, Union | ||||
| import addict | import addict | ||||
| import json | |||||
| from yapf.yapflib.yapf_api import FormatCode | from yapf.yapflib.yapf_api import FormatCode | ||||
| from modelscope.utils.constant import ConfigFields, ModelFile | from modelscope.utils.constant import ConfigFields, ModelFile | ||||
| @@ -627,3 +628,20 @@ def check_config(cfg: Union[str, ConfigDict]): | |||||
| check_attr(ConfigFields.model) | check_attr(ConfigFields.model) | ||||
| check_attr(ConfigFields.preprocessor) | check_attr(ConfigFields.preprocessor) | ||||
| check_attr(ConfigFields.evaluation) | check_attr(ConfigFields.evaluation) | ||||
class JSONIteratorEncoder(json.JSONEncoder):
    """JSON encoder that also serializes arbitrary iterables.

    ``default`` turns any object accepted by ``iter()`` into a list;
    everything else is handed to the base implementation, which raises
    ``TypeError``.
    """

    def default(self, obj):
        """Return ``obj`` as a list if it is iterable, else defer to the base class."""
        try:
            members = iter(obj)
        except TypeError:
            members = None

        if members is not None:
            # Materialize outside the try block so errors raised while
            # iterating are not mistaken for "not iterable".
            return list(members)
        return json.JSONEncoder.default(self, obj)
| @@ -211,6 +211,7 @@ class ModelFile(object): | |||||
| VOCAB_FILE = 'vocab.txt' | VOCAB_FILE = 'vocab.txt' | ||||
| ONNX_MODEL_FILE = 'model.onnx' | ONNX_MODEL_FILE = 'model.onnx' | ||||
| LABEL_MAPPING = 'label_mapping.json' | LABEL_MAPPING = 'label_mapping.json' | ||||
| TRAIN_OUTPUT_DIR = 'output' | |||||
| class ConfigFields(object): | class ConfigFields(object): | ||||
| @@ -10,7 +10,8 @@ from modelscope.hub.constants import Licenses, ModelVisibility | |||||
| from modelscope.hub.file_download import model_file_download | from modelscope.hub.file_download import model_file_download | ||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.utils.config import Config | from modelscope.utils.config import Config | ||||
| from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile | |||||
| from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields, | |||||
| ModelFile) | |||||
| from .logger import get_logger | from .logger import get_logger | ||||
| logger = get_logger(__name__) | logger = get_logger(__name__) | ||||
| @@ -119,8 +120,13 @@ def parse_label_mapping(model_dir): | |||||
| if label2id is None: | if label2id is None: | ||||
| config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) | config_path = os.path.join(model_dir, ModelFile.CONFIGURATION) | ||||
| config = Config.from_file(config_path) | config = Config.from_file(config_path) | ||||
| if hasattr(config, 'model') and hasattr(config.model, 'label2id'): | |||||
| label2id = config.model.label2id | |||||
| if hasattr(config, ConfigFields.model) and hasattr( | |||||
| config[ConfigFields.model], 'label2id'): | |||||
| label2id = config[ConfigFields.model].label2id | |||||
| elif hasattr(config, ConfigFields.preprocessor) and hasattr( | |||||
| config[ConfigFields.preprocessor], 'label2id'): | |||||
| label2id = config[ConfigFields.preprocessor].label2id | |||||
| if label2id is None: | if label2id is None: | ||||
| config_path = os.path.join(model_dir, 'config.json') | config_path = os.path.join(model_dir, 'config.json') | ||||
| config = Config.from_file(config_path) | config = Config.from_file(config_path) | ||||
| @@ -11,6 +11,7 @@ import torch | |||||
| from torch import nn | from torch import nn | ||||
| from modelscope.metainfo import Trainers | from modelscope.metainfo import Trainers | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import LogKeys, ModelFile | from modelscope.utils.constant import LogKeys, ModelFile | ||||
| from modelscope.utils.test_utils import create_dummy_test_dataset | from modelscope.utils.test_utils import create_dummy_test_dataset | ||||
| @@ -19,7 +20,7 @@ dummy_dataset = create_dummy_test_dataset( | |||||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -11,11 +11,14 @@ from torch import nn | |||||
| from modelscope.metainfo import Trainers | from modelscope.metainfo import Trainers | ||||
| from modelscope.metrics.builder import METRICS, MetricKeys | from modelscope.metrics.builder import METRICS, MetricKeys | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import LogKeys, ModelFile | from modelscope.utils.constant import LogKeys, ModelFile | ||||
| from modelscope.utils.registry import default_group | from modelscope.utils.registry import default_group | ||||
| from modelscope.utils.test_utils import create_dummy_test_dataset | from modelscope.utils.test_utils import create_dummy_test_dataset | ||||
| SRC_DIR = os.path.dirname(__file__) | |||||
| def create_dummy_metric(): | def create_dummy_metric(): | ||||
| _global_iter = 0 | _global_iter = 0 | ||||
| @@ -39,12 +42,13 @@ dummy_dataset = create_dummy_test_dataset( | |||||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| self.linear = nn.Linear(5, 4) | self.linear = nn.Linear(5, 4) | ||||
| self.bn = nn.BatchNorm1d(4) | self.bn = nn.BatchNorm1d(4) | ||||
| self.model_dir = SRC_DIR | |||||
| def forward(self, feat, labels): | def forward(self, feat, labels): | ||||
| x = self.linear(feat) | x = self.linear(feat) | ||||
| @@ -123,6 +127,14 @@ class CheckpointHookTest(unittest.TestCase): | |||||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | ||||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | ||||
| output_files = os.listdir( | |||||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||||
| copy_src_files = os.listdir(SRC_DIR) | |||||
| self.assertIn(copy_src_files[0], output_files) | |||||
| self.assertIn(copy_src_files[-1], output_files) | |||||
| class BestCkptSaverHookTest(unittest.TestCase): | class BestCkptSaverHookTest(unittest.TestCase): | ||||
| @@ -198,6 +210,14 @@ class BestCkptSaverHookTest(unittest.TestCase): | |||||
| self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth', | self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth', | ||||
| results_files) | results_files) | ||||
| output_files = os.listdir( | |||||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||||
| copy_src_files = os.listdir(SRC_DIR) | |||||
| self.assertIn(copy_src_files[0], output_files) | |||||
| self.assertIn(copy_src_files[-1], output_files) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||
| @@ -11,6 +11,7 @@ from torch import nn | |||||
| from modelscope.metainfo import Trainers | from modelscope.metainfo import Trainers | ||||
| from modelscope.metrics.builder import METRICS, MetricKeys | from modelscope.metrics.builder import METRICS, MetricKeys | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile | from modelscope.utils.constant import ModelFile | ||||
| from modelscope.utils.registry import default_group | from modelscope.utils.registry import default_group | ||||
| @@ -34,7 +35,7 @@ dummy_dataset = create_dummy_test_dataset( | |||||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -13,6 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR | |||||
| from modelscope.metainfo import Trainers | from modelscope.metainfo import Trainers | ||||
| from modelscope.metrics.builder import METRICS, MetricKeys | from modelscope.metrics.builder import METRICS, MetricKeys | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | ||||
| from modelscope.utils.registry import default_group | from modelscope.utils.registry import default_group | ||||
| @@ -40,7 +41,7 @@ def create_dummy_metric(): | |||||
| return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} | return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]} | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -12,6 +12,7 @@ from torch.optim import SGD | |||||
| from torch.optim.lr_scheduler import MultiStepLR | from torch.optim.lr_scheduler import MultiStepLR | ||||
| from modelscope.metainfo import Trainers | from modelscope.metainfo import Trainers | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile, TrainerStages | from modelscope.utils.constant import ModelFile, TrainerStages | ||||
| from modelscope.utils.test_utils import create_dummy_test_dataset | from modelscope.utils.test_utils import create_dummy_test_dataset | ||||
| @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( | |||||
| np.random.random(size=(2, )), np.random.randint(0, 2, (1, )), 10) | np.random.random(size=(2, )), np.random.randint(0, 2, (1, )), 10) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -12,6 +12,7 @@ from torch.optim import SGD | |||||
| from torch.optim.lr_scheduler import MultiStepLR | from torch.optim.lr_scheduler import MultiStepLR | ||||
| from modelscope.metainfo import Trainers | from modelscope.metainfo import Trainers | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages | ||||
| from modelscope.utils.test_utils import create_dummy_test_dataset | from modelscope.utils.test_utils import create_dummy_test_dataset | ||||
| @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( | |||||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) | np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -83,8 +84,8 @@ class IterTimerHookTest(unittest.TestCase): | |||||
| trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | trainer.train_dataset, **trainer.cfg.train.get('dataloader', {})) | ||||
| trainer.register_optimizers_hook() | trainer.register_optimizers_hook() | ||||
| trainer.register_hook_from_cfg(trainer.cfg.train.hooks) | trainer.register_hook_from_cfg(trainer.cfg.train.hooks) | ||||
| trainer.data_loader = train_dataloader | |||||
| trainer.train_dataloader = train_dataloader | trainer.train_dataloader = train_dataloader | ||||
| trainer.data_loader = train_dataloader | |||||
| trainer.invoke_hook(TrainerStages.before_run) | trainer.invoke_hook(TrainerStages.before_run) | ||||
| for i in range(trainer._epoch, trainer._max_epochs): | for i in range(trainer._epoch, trainer._max_epochs): | ||||
| trainer.invoke_hook(TrainerStages.before_train_epoch) | trainer.invoke_hook(TrainerStages.before_train_epoch) | ||||
| @@ -4,11 +4,18 @@ import shutil | |||||
| import tempfile | import tempfile | ||||
| import unittest | import unittest | ||||
| from modelscope.metainfo import Trainers | |||||
| from modelscope.metainfo import Preprocessors, Trainers | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| class TestFinetuneSequenceClassification(unittest.TestCase): | class TestFinetuneSequenceClassification(unittest.TestCase): | ||||
| epoch_num = 1 | |||||
| sentence1 = '今天气温比昨天高么?' | |||||
| sentence2 = '今天湿度比昨天高么?' | |||||
| def setUp(self): | def setUp(self): | ||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | ||||
| @@ -40,15 +47,32 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| trainer.train() | trainer.train() | ||||
| results_files = os.listdir(self.tmp_dir) | results_files = os.listdir(self.tmp_dir) | ||||
| self.assertIn(f'{trainer.timestamp}.log.json', results_files) | self.assertIn(f'{trainer.timestamp}.log.json', results_files) | ||||
| for i in range(10): | |||||
| for i in range(self.epoch_num): | |||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | self.assertIn(f'epoch_{i+1}.pth', results_files) | ||||
| output_files = os.listdir( | |||||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||||
| copy_src_files = os.listdir(trainer.model_dir) | |||||
| print(f'copy_src_files are {copy_src_files}') | |||||
| print(f'output_files are {output_files}') | |||||
| for item in copy_src_files: | |||||
| if not item.startswith('.'): | |||||
| self.assertIn(item, output_files) | |||||
def pipeline_sentence_similarity(self, model_dir):
    """Load the model saved in ``model_dir`` and print a pipeline prediction.

    The model is loaded with ``Model.from_pretrained``, wrapped in a
    sentence-similarity pipeline, and run on the class-level sentence pair.
    """
    loaded_model = Model.from_pretrained(model_dir)
    similarity_pipeline = pipeline(
        task=Tasks.sentence_similarity, model=loaded_model)
    sentence_pair = (self.sentence1, self.sentence2)
    print(similarity_pipeline(input=sentence_pair))
| @unittest.skip | @unittest.skip | ||||
| def test_finetune_afqmc(self): | def test_finetune_afqmc(self): | ||||
| def cfg_modify_fn(cfg): | def cfg_modify_fn(cfg): | ||||
| cfg.task = 'sentence-similarity' | |||||
| cfg['preprocessor'] = {'type': 'sen-sim-tokenizer'} | |||||
| cfg.task = Tasks.sentence_similarity | |||||
| cfg['preprocessor'] = {'type': Preprocessors.sen_sim_tokenizer} | |||||
| cfg.train.optimizer.lr = 2e-5 | cfg.train.optimizer.lr = 2e-5 | ||||
| cfg['dataset'] = { | cfg['dataset'] = { | ||||
| 'train': { | 'train': { | ||||
| @@ -58,7 +82,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| 'label': 'label', | 'label': 'label', | ||||
| } | } | ||||
| } | } | ||||
| cfg.train.max_epochs = 10 | |||||
| cfg.train.max_epochs = self.epoch_num | |||||
| cfg.train.lr_scheduler = { | cfg.train.lr_scheduler = { | ||||
| 'type': 'LinearLR', | 'type': 'LinearLR', | ||||
| 'start_factor': 1.0, | 'start_factor': 1.0, | ||||
| @@ -95,6 +119,9 @@ class TestFinetuneSequenceClassification(unittest.TestCase): | |||||
| eval_dataset=dataset['validation'], | eval_dataset=dataset['validation'], | ||||
| cfg_modify_fn=cfg_modify_fn) | cfg_modify_fn=cfg_modify_fn) | ||||
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||||
| self.pipeline_sentence_similarity(output_dir) | |||||
| @unittest.skip | @unittest.skip | ||||
| def test_finetune_tnews(self): | def test_finetune_tnews(self): | ||||
| @@ -14,6 +14,7 @@ from torch.utils.data import IterableDataset | |||||
| from modelscope.metainfo import Metrics, Trainers | from modelscope.metainfo import Metrics, Trainers | ||||
| from modelscope.metrics.builder import MetricKeys | from modelscope.metrics.builder import MetricKeys | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | ||||
| from modelscope.utils.test_utils import create_dummy_test_dataset, test_level | from modelscope.utils.test_utils import create_dummy_test_dataset, test_level | ||||
| @@ -35,7 +36,7 @@ dummy_dataset_big = create_dummy_test_dataset( | |||||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) | np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -15,6 +15,7 @@ from torch.utils.data import IterableDataset | |||||
| from modelscope.metainfo import Metrics, Trainers | from modelscope.metainfo import Metrics, Trainers | ||||
| from modelscope.metrics.builder import MetricKeys | from modelscope.metrics.builder import MetricKeys | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers import EpochBasedTrainer, build_trainer | from modelscope.trainers import EpochBasedTrainer, build_trainer | ||||
| from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | ||||
| from modelscope.utils.test_utils import (DistributedTestCase, | from modelscope.utils.test_utils import (DistributedTestCase, | ||||
| @@ -37,7 +38,7 @@ dummy_dataset_big = create_dummy_test_dataset( | |||||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) | np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -6,16 +6,20 @@ import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | from modelscope.hub.snapshot_download import snapshot_download | ||||
| from modelscope.metainfo import Metrics | from modelscope.metainfo import Metrics | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.nlp.sequence_classification import \ | from modelscope.models.nlp.sequence_classification import \ | ||||
| SbertForSequenceClassification | SbertForSequenceClassification | ||||
| from modelscope.msdatasets import MsDataset | from modelscope.msdatasets import MsDataset | ||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.trainers import build_trainer | from modelscope.trainers import build_trainer | ||||
| from modelscope.utils.constant import ModelFile | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.hub import read_config | from modelscope.utils.hub import read_config | ||||
| from modelscope.utils.test_utils import test_level | from modelscope.utils.test_utils import test_level | ||||
| class TestTrainerWithNlp(unittest.TestCase): | class TestTrainerWithNlp(unittest.TestCase): | ||||
| sentence1 = '今天气温比昨天高么?' | |||||
| sentence2 = '今天湿度比昨天高么?' | |||||
| def setUp(self): | def setUp(self): | ||||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | ||||
| @@ -30,7 +34,7 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| shutil.rmtree(self.tmp_dir) | shutil.rmtree(self.tmp_dir) | ||||
| super().tearDown() | super().tearDown() | ||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_trainer(self): | def test_trainer(self): | ||||
| model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base' | ||||
| kwargs = dict( | kwargs = dict( | ||||
| @@ -47,6 +51,27 @@ class TestTrainerWithNlp(unittest.TestCase): | |||||
| for i in range(10): | for i in range(10): | ||||
| self.assertIn(f'epoch_{i+1}.pth', results_files) | self.assertIn(f'epoch_{i+1}.pth', results_files) | ||||
| output_files = os.listdir( | |||||
| os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)) | |||||
| self.assertIn(ModelFile.CONFIGURATION, output_files) | |||||
| self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files) | |||||
| copy_src_files = os.listdir(trainer.model_dir) | |||||
| print(f'copy_src_files are {copy_src_files}') | |||||
| print(f'output_files are {output_files}') | |||||
| for item in copy_src_files: | |||||
| if not item.startswith('.'): | |||||
| self.assertIn(item, output_files) | |||||
def pipeline_sentence_similarity(model_dir):
    """Load the model saved in ``model_dir`` and print a pipeline prediction.

    Uses ``self.sentence1`` / ``self.sentence2`` from the enclosing scope
    as the input pair for a sentence-similarity pipeline.
    """
    loaded_model = Model.from_pretrained(model_dir)
    similarity_pipeline = pipeline(
        task=Tasks.sentence_similarity, model=loaded_model)
    sentence_pair = (self.sentence1, self.sentence2)
    print(similarity_pipeline(input=sentence_pair))
| output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) | |||||
| pipeline_sentence_similarity(output_dir) | |||||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | ||||
| def test_trainer_with_backbone_head(self): | def test_trainer_with_backbone_head(self): | ||||
| model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' | ||||
| @@ -11,6 +11,7 @@ from torch.utils.data import DataLoader | |||||
| from modelscope.metrics.builder import MetricKeys | from modelscope.metrics.builder import MetricKeys | ||||
| from modelscope.metrics.sequence_classification_metric import \ | from modelscope.metrics.sequence_classification_metric import \ | ||||
| SequenceClassificationMetric | SequenceClassificationMetric | ||||
| from modelscope.models.base import Model | |||||
| from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test | from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test | ||||
| from modelscope.utils.test_utils import (DistributedTestCase, | from modelscope.utils.test_utils import (DistributedTestCase, | ||||
| create_dummy_test_dataset, test_level) | create_dummy_test_dataset, test_level) | ||||
| @@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset( | |||||
| torch.rand((5, )), torch.randint(0, 4, (1, )), 20) | torch.rand((5, )), torch.randint(0, 4, (1, )), 20) | ||||
| class DummyModel(nn.Module): | |||||
| class DummyModel(nn.Module, Model): | |||||
| def __init__(self): | def __init__(self): | ||||
| super().__init__() | super().__init__() | ||||