[to #42322933]Add model.save_pretrained method and allow finetune results to be used by pipeline

master
zhangzhicheng.zzc, 3 years ago
commit b94bb74f66
19 changed files with 254 additions and 29 deletions
1. +1 -1 modelscope/fileio/__init__.py
2. +1 -1 modelscope/fileio/file.py
3. +28 -4 modelscope/models/base/base_model.py
4. +16 -1 modelscope/trainers/hooks/checkpoint_hook.py
5. +83 -2 modelscope/utils/checkpoint.py
6. +18 -0 modelscope/utils/config.py
7. +1 -0 modelscope/utils/constant.py
8. +9 -3 modelscope/utils/hub.py
9. +2 -1 tests/trainers/hooks/logger/test_tensorboard_hook.py
10. +21 -1 tests/trainers/hooks/test_checkpoint_hook.py
11. +2 -1 tests/trainers/hooks/test_evaluation_hook.py
12. +2 -1 tests/trainers/hooks/test_lr_scheduler_hook.py
13. +2 -1 tests/trainers/hooks/test_optimizer_hook.py
14. +3 -2 tests/trainers/hooks/test_timer_hook.py
15. +32 -5 tests/trainers/test_finetune_sequence_classification.py
16. +2 -1 tests/trainers/test_trainer.py
17. +2 -1 tests/trainers/test_trainer_gpu.py
18. +27 -2 tests/trainers/test_trainer_with_nlp.py
19. +2 -1 tests/trainers/utils/test_inference.py
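The net effect of this commit: after finetuning, the checkpoint hook writes a self-contained model folder that pipeline() can consume directly. A minimal sketch of the intended flow, mirroring the tests in this diff (work_dir and the sample sentences are illustrative, not part of the commit):

import os

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ModelFile, Tasks

# After trainer.train() completes, CheckpointHook copies the model assets
# and dumps configuration.json plus pytorch_model.bin into
# <work_dir>/output (ModelFile.TRAIN_OUTPUT_DIR).
work_dir = './work_dir'  # illustrative; matches the trainer's work_dir
output_dir = os.path.join(work_dir, ModelFile.TRAIN_OUTPUT_DIR)

model = Model.from_pretrained(output_dir)
pipe = pipeline(task=Tasks.sentence_similarity, model=model)
print(pipe(input=('今天气温比昨天高么?', '今天湿度比昨天高么?')))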

+1 -1 modelscope/fileio/__init__.py

@@ -1,2 +1,2 @@
-from .file import File
+from .file import File, LocalStorage
 from .io import dump, dumps, load

+1 -1 modelscope/fileio/file.py

@@ -240,7 +240,7 @@ class File(object):
     @staticmethod
     def _get_storage(uri):
         assert isinstance(uri,
-                          str), f'uri should be str type, buf got {type(uri)}'
+                          str), f'uri should be str type, but got {type(uri)}'

         if '://' not in uri:
             # local path


+28 -4 modelscope/models/base/base_model.py

@@ -1,13 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import os.path as osp
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Union
-
-import numpy as np
-
+from typing import Callable, Dict, List, Optional, Union
+
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.builder import build_model
+from modelscope.utils.checkpoint import save_pretrained
 from modelscope.utils.config import Config
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
 from modelscope.utils.device import device_placement, verify_device
@@ -119,3 +118,28 @@ class Model(ABC):
         if hasattr(cfg, 'pipeline'):
             model.pipeline = cfg.pipeline
         return model
+
+    def save_pretrained(self,
+                        target_folder: Union[str, os.PathLike],
+                        save_checkpoint_names: Union[str, List[str]] = None,
+                        save_function: Callable = None,
+                        config: Optional[dict] = None,
+                        **kwargs):
+        """Save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded.
+
+        Args:
+            target_folder (Union[str, os.PathLike]):
+                Directory to which to save. Will be created if it doesn't exist.
+
+            save_checkpoint_names (Union[str, List[str]]):
+                The checkpoint names to be saved in the target_folder.
+
+            save_function (Callable, optional):
+                The function to use to save the state dictionary.
+
+            config (Optional[dict], optional):
+                The config for the configuration.json, might not be identical to model.config.
+
+        """
+        save_pretrained(self, target_folder, save_checkpoint_names,
+                        save_function, config, **kwargs)
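Outside the trainer, the method can also be called directly; note that the hook below passes save_checkpoint as the save_function and the trainer config as config, and all three arguments are mandatory. A hedged sketch (the model id is borrowed from the tests in this commit; model.model_dir and Config.to_dict are assumed available, as the hook and hub code here use them):

import os

from modelscope.models import Model
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')
cfg = Config.from_file(os.path.join(model.model_dir, ModelFile.CONFIGURATION))

# Checkpoint name, save_function and config are all required;
# save_pretrained raises if any of them is missing.
model.save_pretrained(
    './saved_model',
    ModelFile.TORCH_MODEL_BIN_FILE,
    save_function=save_checkpoint,
    config=cfg.to_dict())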

+16 -1 modelscope/trainers/hooks/checkpoint_hook.py

@@ -1,10 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
 import json

 from modelscope import __version__
 from modelscope.metainfo import Hooks
 from modelscope.utils.checkpoint import save_checkpoint
-from modelscope.utils.constant import LogKeys
+from modelscope.utils.constant import LogKeys, ModelFile
 from modelscope.utils.logger import get_logger
 from modelscope.utils.torch_utils import is_master
 from .builder import HOOKS
@@ -73,6 +75,18 @@ class CheckpointHook(Hook):
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

         save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
+        self._save_pretrained(trainer)
+
+    def _save_pretrained(self, trainer):
+        if self.is_last_epoch(trainer) and self.by_epoch:
+            output_dir = os.path.join(self.save_dir,
+                                      ModelFile.TRAIN_OUTPUT_DIR)
+
+            trainer.model.save_pretrained(
+                output_dir,
+                ModelFile.TORCH_MODEL_BIN_FILE,
+                save_function=save_checkpoint,
+                config=trainer.cfg.to_dict())

     def after_train_iter(self, trainer):
         if self.by_epoch:
@@ -166,3 +180,4 @@ class BestCkptSaverHook(CheckpointHook):
                 )
         save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
         self._best_ckpt_file = cur_save_name
+        self._save_pretrained(trainer)

+83 -2 modelscope/utils/checkpoint.py

@@ -1,15 +1,23 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import io
 import os
 import time
 from collections import OrderedDict
-from typing import Optional
+from shutil import copytree, ignore_patterns, rmtree
+from typing import Callable, List, Optional, Union

 import json
 import numpy as np
 import torch
 from torch.optim import Optimizer

 from modelscope import __version__
-from modelscope.fileio import File
+from modelscope.fileio import File, LocalStorage
+from modelscope.utils.config import JSONIteratorEncoder
 from modelscope.utils.constant import ConfigFields, ModelFile
+
+storage = LocalStorage()


 def weights_to_cpu(state_dict):
@@ -72,3 +80,76 @@ def save_checkpoint(model: torch.nn.Module,
     with io.BytesIO() as f:
         torch.save(checkpoint, f)
         File.write(f.getvalue(), filename)
+
+
+def save_pretrained(model,
+                    target_folder: Union[str, os.PathLike],
+                    save_checkpoint_name: str = None,
+                    save_function: Callable = None,
+                    config: Optional[dict] = None,
+                    **kwargs):
+    """Save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded.
+
+    Args:
+        model (Model): Model whose params are to be saved.
+
+        target_folder (Union[str, os.PathLike]):
+            Directory to which to save. Will be created if it doesn't exist.
+
+        save_checkpoint_name (str):
+            The checkpoint name to be saved in the target_folder.
+
+        save_function (Callable, optional):
+            The function to use to save the state dictionary.
+
+        config (Optional[dict], optional):
+            The config for the configuration.json, might not be identical to model.config.
+    """
+
+    if save_function is None or not isinstance(save_function, Callable):
+        raise Exception('A valid save function must be passed in')
+
+    if target_folder is None or os.path.isfile(target_folder):
+        raise ValueError(
+            f'Provided path ({target_folder}) should be a directory, not a file'
+        )
+
+    if save_checkpoint_name is None:
+        raise Exception(
+            'At least one checkpoint name must be passed in for saving')
+
+    if config is None:
+        raise ValueError('Configuration is not valid')
+
+    # Clean the folder from a previous save
+    if os.path.exists(target_folder):
+        rmtree(target_folder)
+
+    # Single ckpt path, sharded ckpt logic will be added later
+    output_ckpt_path = os.path.join(target_folder, save_checkpoint_name)
+
+    # Copy the model's files to the save directory, ignoring the original ckpts and configuration
+    origin_file_to_be_ignored = [save_checkpoint_name]
+    ignore_file_set = set(origin_file_to_be_ignored)
+    ignore_file_set.add(ModelFile.CONFIGURATION)
+    ignore_file_set.add('.*')
+    if hasattr(model, 'model_dir') and model.model_dir is not None:
+        copytree(
+            model.model_dir,
+            target_folder,
+            ignore=ignore_patterns(*ignore_file_set))
+
+    # Save the ckpt to the save directory
+    try:
+        save_function(model, output_ckpt_path)
+    except Exception as e:
+        raise Exception(
+            f'During saving checkpoints, an error of type "{type(e).__name__}" '
+            f'with message "{e}" was thrown')
+
+    # Dump the config to configuration.json
+    if ConfigFields.pipeline not in config:
+        config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]}
+    cfg_str = json.dumps(config, cls=JSONIteratorEncoder)
+    config_file = os.path.join(target_folder, ModelFile.CONFIGURATION)
+    storage.write(cfg_str.encode(), config_file)
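One behavior worth calling out: when the supplied config carries no pipeline section, save_pretrained synthesizes one from the task field so that pipeline() can later resolve the folder. A plain-dict sketch of that fallback, assuming ConfigFields.pipeline and ConfigFields.task are the string keys 'pipeline' and 'task' (values illustrative):

# Illustrative config without an explicit pipeline section.
config = {'task': 'sentence-similarity', 'model': {'type': 'structbert'}}

# save_pretrained fills in the pipeline type from the task before
# dumping configuration.json:
if 'pipeline' not in config:
    config['pipeline'] = {'type': config['task']}

assert config['pipeline'] == {'type': 'sentence-similarity'}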

+18 -0 modelscope/utils/config.py

@@ -12,6 +12,7 @@ from pathlib import Path
 from typing import Dict, Union

 import addict
+import json
 from yapf.yapflib.yapf_api import FormatCode

 from modelscope.utils.constant import ConfigFields, ModelFile
@@ -627,3 +628,20 @@ def check_config(cfg: Union[str, ConfigDict]):
     check_attr(ConfigFields.model)
     check_attr(ConfigFields.preprocessor)
     check_attr(ConfigFields.evaluation)
+
+
+class JSONIteratorEncoder(json.JSONEncoder):
+    """Implements ``default`` so that arbitrary iterators are supported:
+    returns a serializable object for ``obj``, or calls the base
+    implementation (to raise a ``TypeError``).
+    """
+
+    def default(self, obj):
+        try:
+            iterable = iter(obj)
+        except TypeError:
+            pass
+        else:
+            return list(iterable)
+        return json.JSONEncoder.default(self, obj)
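This encoder exists because a trainer config may hold sets, ranges or other iterables that the stock JSONEncoder rejects when configuration.json is dumped. A quick sketch of the behavior (the metric name is illustrative):

import json

from modelscope.utils.config import JSONIteratorEncoder

cfg = {'metrics': {'seq-cls-metric'}, 'epochs': range(1, 3)}
# Any iterable is materialized into a JSON list; non-iterables still
# raise TypeError via the base implementation.
print(json.dumps(cfg, cls=JSONIteratorEncoder))
# -> {"metrics": ["seq-cls-metric"], "epochs": [1, 2]}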

+1 -0 modelscope/utils/constant.py

@@ -211,6 +211,7 @@ class ModelFile(object):
     VOCAB_FILE = 'vocab.txt'
     ONNX_MODEL_FILE = 'model.onnx'
     LABEL_MAPPING = 'label_mapping.json'
+    TRAIN_OUTPUT_DIR = 'output'


 class ConfigFields(object):


+9 -3 modelscope/utils/hub.py

@@ -10,7 +10,8 @@ from modelscope.hub.constants import Licenses, ModelVisibility
 from modelscope.hub.file_download import model_file_download
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.utils.config import Config
-from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
+from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
+                                       ModelFile)
 from .logger import get_logger

 logger = get_logger(__name__)
@@ -119,8 +120,13 @@ def parse_label_mapping(model_dir):
     if label2id is None:
         config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
         config = Config.from_file(config_path)
-        if hasattr(config, 'model') and hasattr(config.model, 'label2id'):
-            label2id = config.model.label2id
+        if hasattr(config, ConfigFields.model) and hasattr(
+                config[ConfigFields.model], 'label2id'):
+            label2id = config[ConfigFields.model].label2id
+        elif hasattr(config, ConfigFields.preprocessor) and hasattr(
+                config[ConfigFields.preprocessor], 'label2id'):
+            label2id = config[ConfigFields.preprocessor].label2id

     if label2id is None:
         config_path = os.path.join(model_dir, 'config.json')
         config = Config.from_file(config_path)
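In practice this means parse_label_mapping now accepts either configuration.json layout; both fragments below are illustrative (label names and ids invented):

# label2id under the model section (checked first):
cfg_a = {'model': {'label2id': {'negative': 0, 'positive': 1}}}

# label2id under the preprocessor section (the new fallback):
cfg_b = {'preprocessor': {'label2id': {'negative': 0, 'positive': 1}}}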


+2 -1 tests/trainers/hooks/logger/test_tensorboard_hook.py

@@ -11,6 +11,7 @@ import torch
 from torch import nn

 from modelscope.metainfo import Trainers
+from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModelFile
 from modelscope.utils.test_utils import create_dummy_test_dataset
@@ -19,7 +20,7 @@ dummy_dataset = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()


+21 -1 tests/trainers/hooks/test_checkpoint_hook.py

@@ -11,11 +11,14 @@ from torch import nn

 from modelscope.metainfo import Trainers
 from modelscope.metrics.builder import METRICS, MetricKeys
+from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModelFile
 from modelscope.utils.registry import default_group
 from modelscope.utils.test_utils import create_dummy_test_dataset

+SRC_DIR = os.path.dirname(__file__)
+

 def create_dummy_metric():
     _global_iter = 0
@@ -39,12 +42,13 @@ dummy_dataset = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()
         self.linear = nn.Linear(5, 4)
         self.bn = nn.BatchNorm1d(4)
+        self.model_dir = SRC_DIR

     def forward(self, feat, labels):
         x = self.linear(feat)
@@ -123,6 +127,14 @@ class CheckpointHookTest(unittest.TestCase):
         self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
         self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)

+        output_files = os.listdir(
+            os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR))
+        self.assertIn(ModelFile.CONFIGURATION, output_files)
+        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files)
+        copy_src_files = os.listdir(SRC_DIR)
+        self.assertIn(copy_src_files[0], output_files)
+        self.assertIn(copy_src_files[-1], output_files)
+

 class BestCkptSaverHookTest(unittest.TestCase):

@@ -198,6 +210,14 @@ class BestCkptSaverHookTest(unittest.TestCase):
         self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth',
                       results_files)

+        output_files = os.listdir(
+            os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR))
+        self.assertIn(ModelFile.CONFIGURATION, output_files)
+        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files)
+        copy_src_files = os.listdir(SRC_DIR)
+        self.assertIn(copy_src_files[0], output_files)
+        self.assertIn(copy_src_files[-1], output_files)
+

 if __name__ == '__main__':
     unittest.main()

+2 -1 tests/trainers/hooks/test_evaluation_hook.py

@@ -11,6 +11,7 @@ from torch import nn

 from modelscope.metainfo import Trainers
 from modelscope.metrics.builder import METRICS, MetricKeys
+from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.registry import default_group
@@ -34,7 +35,7 @@ dummy_dataset = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()


+2 -1 tests/trainers/hooks/test_lr_scheduler_hook.py

@@ -13,6 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR

 from modelscope.metainfo import Trainers
 from modelscope.metrics.builder import METRICS, MetricKeys
+from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
 from modelscope.utils.registry import default_group
@@ -40,7 +41,7 @@ def create_dummy_metric():
         return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]}


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()


+2 -1 tests/trainers/hooks/test_optimizer_hook.py

@@ -12,6 +12,7 @@ from torch.optim import SGD
 from torch.optim.lr_scheduler import MultiStepLR

 from modelscope.metainfo import Trainers
+from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import ModelFile, TrainerStages
 from modelscope.utils.test_utils import create_dummy_test_dataset
@@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset(
     np.random.random(size=(2, )), np.random.randint(0, 2, (1, )), 10)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()


+3 -2 tests/trainers/hooks/test_timer_hook.py

@@ -12,6 +12,7 @@ from torch.optim import SGD
 from torch.optim.lr_scheduler import MultiStepLR

 from modelscope.metainfo import Trainers
+from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
 from modelscope.utils.test_utils import create_dummy_test_dataset
@@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()
@@ -83,8 +84,8 @@ class IterTimerHookTest(unittest.TestCase):
             trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
         trainer.register_optimizers_hook()
         trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
-        trainer.data_loader = train_dataloader
         trainer.train_dataloader = train_dataloader
+        trainer.data_loader = train_dataloader
         trainer.invoke_hook(TrainerStages.before_run)
         for i in range(trainer._epoch, trainer._max_epochs):
             trainer.invoke_hook(TrainerStages.before_train_epoch)


+32 -5 tests/trainers/test_finetune_sequence_classification.py

@@ -4,11 +4,18 @@ import shutil
 import tempfile
 import unittest

-from modelscope.metainfo import Trainers
+from modelscope.metainfo import Preprocessors, Trainers
+from modelscope.models import Model
+from modelscope.pipelines import pipeline
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import ModelFile, Tasks


 class TestFinetuneSequenceClassification(unittest.TestCase):
+    epoch_num = 1
+
+    sentence1 = '今天气温比昨天高么?'
+    sentence2 = '今天湿度比昨天高么?'

     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
@@ -40,15 +47,32 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(10):
+        for i in range(self.epoch_num):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

+        output_files = os.listdir(
+            os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR))
+        self.assertIn(ModelFile.CONFIGURATION, output_files)
+        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files)
+        copy_src_files = os.listdir(trainer.model_dir)
+
+        print(f'copy_src_files are {copy_src_files}')
+        print(f'output_files are {output_files}')
+        for item in copy_src_files:
+            if not item.startswith('.'):
+                self.assertIn(item, output_files)
+
+    def pipeline_sentence_similarity(self, model_dir):
+        model = Model.from_pretrained(model_dir)
+        pipeline_ins = pipeline(task=Tasks.sentence_similarity, model=model)
+        print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
     @unittest.skip
     def test_finetune_afqmc(self):

         def cfg_modify_fn(cfg):
-            cfg.task = 'sentence-similarity'
-            cfg['preprocessor'] = {'type': 'sen-sim-tokenizer'}
+            cfg.task = Tasks.sentence_similarity
+            cfg['preprocessor'] = {'type': Preprocessors.sen_sim_tokenizer}
             cfg.train.optimizer.lr = 2e-5
             cfg['dataset'] = {
                 'train': {
@@ -58,7 +82,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'label': 'label',
                 }
             }
-            cfg.train.max_epochs = 10
+            cfg.train.max_epochs = self.epoch_num
             cfg.train.lr_scheduler = {
                 'type': 'LinearLR',
                 'start_factor': 1.0,
@@ -95,6 +119,9 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
             eval_dataset=dataset['validation'],
             cfg_modify_fn=cfg_modify_fn)

+        output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        self.pipeline_sentence_similarity(output_dir)
+
     @unittest.skip
     def test_finetune_tnews(self):



+2 -1 tests/trainers/test_trainer.py

@@ -14,6 +14,7 @@ from torch.utils.data import IterableDataset

 from modelscope.metainfo import Metrics, Trainers
 from modelscope.metrics.builder import MetricKeys
+from modelscope.models.base import Model
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
 from modelscope.utils.test_utils import create_dummy_test_dataset, test_level
@@ -35,7 +36,7 @@ dummy_dataset_big = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()


+2 -1 tests/trainers/test_trainer_gpu.py

@@ -15,6 +15,7 @@ from torch.utils.data import IterableDataset

 from modelscope.metainfo import Metrics, Trainers
 from modelscope.metrics.builder import MetricKeys
+from modelscope.models.base import Model
 from modelscope.trainers import EpochBasedTrainer, build_trainer
 from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
 from modelscope.utils.test_utils import (DistributedTestCase,
@@ -37,7 +38,7 @@ dummy_dataset_big = create_dummy_test_dataset(
     np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()


+27 -2 tests/trainers/test_trainer_with_nlp.py

@@ -6,16 +6,20 @@ import unittest

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Metrics
+from modelscope.models.base import Model
 from modelscope.models.nlp.sequence_classification import \
     SbertForSequenceClassification
 from modelscope.msdatasets import MsDataset
+from modelscope.pipelines import pipeline
 from modelscope.trainers import build_trainer
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.hub import read_config
 from modelscope.utils.test_utils import test_level


 class TestTrainerWithNlp(unittest.TestCase):
+    sentence1 = '今天气温比昨天高么?'
+    sentence2 = '今天湿度比昨天高么?'

     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
@@ -30,7 +34,7 @@ class TestTrainerWithNlp(unittest.TestCase):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()

-    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer(self):
         model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
         kwargs = dict(
@@ -47,6 +51,27 @@ class TestTrainerWithNlp(unittest.TestCase):
         for i in range(10):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

+        output_files = os.listdir(
+            os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR))
+        self.assertIn(ModelFile.CONFIGURATION, output_files)
+        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, output_files)
+        copy_src_files = os.listdir(trainer.model_dir)
+
+        print(f'copy_src_files are {copy_src_files}')
+        print(f'output_files are {output_files}')
+        for item in copy_src_files:
+            if not item.startswith('.'):
+                self.assertIn(item, output_files)
+
+        def pipeline_sentence_similarity(model_dir):
+            model = Model.from_pretrained(model_dir)
+            pipeline_ins = pipeline(
+                task=Tasks.sentence_similarity, model=model)
+            print(pipeline_ins(input=(self.sentence1, self.sentence2)))
+
+        output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
+        pipeline_sentence_similarity(output_dir)
+
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_trainer_with_backbone_head(self):
         model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'


+2 -1 tests/trainers/utils/test_inference.py

@@ -11,6 +11,7 @@ from torch.utils.data import DataLoader
 from modelscope.metrics.builder import MetricKeys
 from modelscope.metrics.sequence_classification_metric import \
     SequenceClassificationMetric
+from modelscope.models.base import Model
 from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test
 from modelscope.utils.test_utils import (DistributedTestCase,
                                          create_dummy_test_dataset, test_level)
@@ -20,7 +21,7 @@ dummy_dataset = create_dummy_test_dataset(
     torch.rand((5, )), torch.randint(0, 4, (1, )), 20)


-class DummyModel(nn.Module):
+class DummyModel(nn.Module, Model):

     def __init__(self):
         super().__init__()

