From c2b1ff8389887c056b866347654a56c05e99ab81 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Wed, 21 Sep 2022 14:25:06 +0800
Subject: [PATCH] [to #42322933] Add exporter module for onnx, ts and other formats.

1. Add the exporter module.
2. Move collate_fn out of the base pipeline class for reuse.
3. Add a dummy-inputs method to the nlp tokenization preprocessor base class.
4. Support Mapping in tensor numpifying and detaching.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10037704
---
 modelscope/exporters/__init__.py              |   4 +
 modelscope/exporters/base.py                  |  53 ++++
 modelscope/exporters/builder.py               |  21 ++
 modelscope/exporters/nlp/__init__.py          |   2 +
 ...rt_for_sequence_classification_exporter.py |  81 ++++++
 modelscope/exporters/torch_model_exporter.py  | 247 ++++++++++++++++++
 modelscope/pipelines/base.py                  |  81 +++---
 modelscope/utils/constant.py                  |   1 +
 modelscope/utils/regress_test_utils.py        |  18 +-
 modelscope/utils/tensor_utils.py              |  22 ++
 tests/export/__init__.py                      |   0
 ...st_export_sbert_sequence_classification.py |  37 +++
 12 files changed, 520 insertions(+), 47 deletions(-)
 create mode 100644 modelscope/exporters/__init__.py
 create mode 100644 modelscope/exporters/base.py
 create mode 100644 modelscope/exporters/builder.py
 create mode 100644 modelscope/exporters/nlp/__init__.py
 create mode 100644 modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py
 create mode 100644 modelscope/exporters/torch_model_exporter.py
 create mode 100644 tests/export/__init__.py
 create mode 100644 tests/export/test_export_sbert_sequence_classification.py

diff --git a/modelscope/exporters/__init__.py b/modelscope/exporters/__init__.py
new file mode 100644
index 00000000..a597114f
--- /dev/null
+++ b/modelscope/exporters/__init__.py
@@ -0,0 +1,4 @@
+from .base import Exporter
+from .builder import build_exporter
+from .nlp import SbertForSequenceClassificationExporter
+from .torch_model_exporter import TorchModelExporter
diff --git a/modelscope/exporters/base.py b/modelscope/exporters/base.py
new file mode 100644
index 00000000..f19d2bbb
--- /dev/null
+++ b/modelscope/exporters/base.py
@@ -0,0 +1,53 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from abc import ABC, abstractmethod
+
+from modelscope.models import Model
+from modelscope.utils.config import Config, ConfigDict
+from modelscope.utils.constant import ModelFile
+from .builder import build_exporter
+
+
+class Exporter(ABC):
+    """Exporter base class for outputting a model to onnx, torch_script, graphdef, etc.
+    """
+
+    def __init__(self):
+        self.model = None
+
+    @classmethod
+    def from_model(cls, model: Model, **kwargs):
+        """Build the Exporter instance.
+
+        @param model: A model instance. It will be used to output the generated files,
+            and the configuration.json in its model_dir field will be used to create the exporter instance.
+        @param kwargs: Extra kwargs used to create the Exporter instance.
+        @return: The Exporter instance.
+        """
+        cfg = Config.from_file(
+            os.path.join(model.model_dir, ModelFile.CONFIGURATION))
+        task_name = cfg.task
+        model_cfg = cfg.model
+        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
+            model_cfg.type = model_cfg.model_type
+        export_cfg = ConfigDict({'type': model_cfg.type})
+        if hasattr(cfg, 'export'):
+            export_cfg.update(cfg.export)
+        exporter = build_exporter(export_cfg, task_name, kwargs)
+        exporter.model = model
+        return exporter
+
+    @abstractmethod
+    def export_onnx(self, outputs: str, opset=11, **kwargs):
+        """Export the model as onnx format files.
+
+        In some cases several files may be generated, so please return a dict
+        mapping each generated name to its file path.
+
+        @param opset: The version of the ONNX operator set to use.
+        @param outputs: The output dir.
+        @param kwargs: In this default implementation,
+            kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape).
+        @return: A dict mapping each model name to the corresponding model file path.
+        """
+        pass
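
Reviewer note: a quick usage sketch of the API surface defined above. The model id is the one used in the test at the end of this patch; the output directory is an illustrative placeholder assumed to exist:

```python
# Usage sketch (not part of this patch); assumes '/tmp/export' already exists.
from modelscope.exporters import Exporter
from modelscope.models import Model

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')
exporter = Exporter.from_model(model)
# Returns a dict mapping each generated name to its file path,
# e.g. {'model': '/tmp/export/model.onnx'}.
print(exporter.export_onnx(shape=(2, 256), outputs='/tmp/export'))
```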
diff --git a/modelscope/exporters/builder.py b/modelscope/exporters/builder.py
new file mode 100644
index 00000000..90699c12
--- /dev/null
+++ b/modelscope/exporters/builder.py
@@ -0,0 +1,21 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from modelscope.utils.config import ConfigDict
+from modelscope.utils.registry import Registry, build_from_cfg
+
+EXPORTERS = Registry('exporters')
+
+
+def build_exporter(cfg: ConfigDict,
+                   task_name: str = None,
+                   default_args: dict = None):
+    """Build an exporter from the given config dict.
+
+    Args:
+        cfg (:obj:`ConfigDict`): config dict for the exporter object.
+        task_name (str, optional): task name, refer to
+            :obj:`Tasks` for more details.
+        default_args (dict, optional): Default initialization arguments.
+    """
+    return build_from_cfg(
+        cfg, EXPORTERS, group_key=task_name, default_args=default_args)
diff --git a/modelscope/exporters/nlp/__init__.py b/modelscope/exporters/nlp/__init__.py
new file mode 100644
index 00000000..fdfd2711
--- /dev/null
+++ b/modelscope/exporters/nlp/__init__.py
@@ -0,0 +1,2 @@
+from .sbert_for_sequence_classification_exporter import \
+    SbertForSequenceClassificationExporter
diff --git a/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py
new file mode 100644
index 00000000..dc1e2b92
--- /dev/null
+++ b/modelscope/exporters/nlp/sbert_for_sequence_classification_exporter.py
@@ -0,0 +1,81 @@
+import os
+from collections import OrderedDict
+from typing import Any, Dict, Mapping, Tuple
+
+from torch.utils.data.dataloader import default_collate
+
+from modelscope.exporters.builder import EXPORTERS
+from modelscope.exporters.torch_model_exporter import TorchModelExporter
+from modelscope.metainfo import Models
+from modelscope.preprocessors import Preprocessor, build_preprocessor
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModeKeys, Tasks
+
+
+@EXPORTERS.register_module(
+    Tasks.sentence_similarity, module_name=Models.structbert)
+@EXPORTERS.register_module(
+    Tasks.sentiment_classification, module_name=Models.structbert)
+@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert)
+@EXPORTERS.register_module(
+    Tasks.zero_shot_classification, module_name=Models.structbert)
+class SbertForSequenceClassificationExporter(TorchModelExporter):
+
+    def generate_dummy_inputs(self,
+                              shape: Tuple = None,
+                              **kwargs) -> Dict[str, Any]:
+        """Generate dummy inputs for exporting the model to onnx or other formats by tracing.
+
+        @param shape: A tuple describing the input shape, with at most two dimensions.
+            shape = (1, ): batch_size=1, the sequence_length will be taken from the preprocessor.
+            shape = (8, 128): batch_size=8, sequence_length=128, which overrides the
+            sequence length configured in the preprocessor.
+        @return: Dummy inputs.
+        """
+
+        cfg = Config.from_file(
+            os.path.join(self.model.model_dir, 'configuration.json'))
+        field_name = Tasks.find_field_by_task(cfg.task)
+        if 'type' not in cfg.preprocessor and 'val' in cfg.preprocessor:
+            cfg = cfg.preprocessor.val
+        else:
+            cfg = cfg.preprocessor
+
+        batch_size = 1
+        sequence_length = {}
+        if shape is not None:
+            if len(shape) == 1:
+                batch_size = shape[0]
+            elif len(shape) == 2:
+                batch_size, max_length = shape
+                sequence_length = {'sequence_length': max_length}
+
+        cfg.update({
+            'model_dir': self.model.model_dir,
+            'mode': ModeKeys.TRAIN,
+            **sequence_length
+        })
+        preprocessor: Preprocessor = build_preprocessor(cfg, field_name)
+        first_sequence = preprocessor.tokenizer.unk_token
+        second_sequence = preprocessor.tokenizer.unk_token \
+            if preprocessor.pair else None
+
+        batched = []
+        for _ in range(batch_size):
+            batched.append(preprocessor((first_sequence, second_sequence)))
+        return default_collate(batched)
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        dynamic_axis = {0: 'batch', 1: 'sequence'}
+        return OrderedDict([
+            ('input_ids', dynamic_axis),
+            ('attention_mask', dynamic_axis),
+            ('token_type_ids', dynamic_axis),
+        ])
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict({'logits': {0: 'batch'}})
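
The dummy batch returned above is just a dict of stacked tensors. A minimal sketch of the same collation contract, with hand-built features standing in for the real preprocessor output:

```python
# Sketch of the dummy-batch shape contract (hypothetical values).
import torch
from torch.utils.data.dataloader import default_collate

batch_size, seq_len = 2, 8
one_example = {
    'input_ids': torch.zeros(seq_len, dtype=torch.long),
    'attention_mask': torch.ones(seq_len, dtype=torch.long),
    'token_type_ids': torch.zeros(seq_len, dtype=torch.long),
}
batch = default_collate([one_example] * batch_size)
# Each value gains a leading batch dimension: (2, 8).
print({k: tuple(v.shape) for k, v in batch.items()})
```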
+ """ + + cfg = Config.from_file( + os.path.join(self.model.model_dir, 'configuration.json')) + field_name = Tasks.find_field_by_task(cfg.task) + if 'type' not in cfg.preprocessor and 'val' in cfg.preprocessor: + cfg = cfg.preprocessor.val + else: + cfg = cfg.preprocessor + + batch_size = 1 + sequence_length = {} + if shape is not None: + if len(shape) == 1: + batch_size = shape[0] + elif len(shape) == 2: + batch_size, max_length = shape + sequence_length = {'sequence_length': max_length} + + cfg.update({ + 'model_dir': self.model.model_dir, + 'mode': ModeKeys.TRAIN, + **sequence_length + }) + preprocessor: Preprocessor = build_preprocessor(cfg, field_name) + if preprocessor.pair: + first_sequence = preprocessor.tokenizer.unk_token + second_sequence = preprocessor.tokenizer.unk_token + else: + first_sequence = preprocessor.tokenizer.unk_token + second_sequence = None + + batched = [] + for _ in range(batch_size): + batched.append(preprocessor((first_sequence, second_sequence))) + return default_collate(batched) + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + dynamic_axis = {0: 'batch', 1: 'sequence'} + return OrderedDict([ + ('input_ids', dynamic_axis), + ('attention_mask', dynamic_axis), + ('token_type_ids', dynamic_axis), + ]) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict({'logits': {0: 'batch'}}) diff --git a/modelscope/exporters/torch_model_exporter.py b/modelscope/exporters/torch_model_exporter.py new file mode 100644 index 00000000..98a23fe5 --- /dev/null +++ b/modelscope/exporters/torch_model_exporter.py @@ -0,0 +1,247 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +from contextlib import contextmanager +from itertools import chain +from typing import Any, Dict, Mapping + +import torch +from torch import nn +from torch.onnx import export as onnx_export +from torch.onnx.utils import _decide_input_format + +from modelscope.models import TorchModel +from modelscope.pipelines.base import collate_fn +from modelscope.utils.constant import ModelFile +from modelscope.utils.logger import get_logger +from modelscope.utils.regress_test_utils import compare_arguments_nested +from modelscope.utils.tensor_utils import torch_nested_numpify +from .base import Exporter + +logger = get_logger(__name__) + + +class TorchModelExporter(Exporter): + """The torch base class of exporter. + + This class provides the default implementations for exporting onnx and torch script. + Each specific model may implement its own exporter by overriding the export_onnx/export_torch_script, + and to provide implementations for generate_dummy_inputs/inputs/outputs methods. + """ + + def export_onnx(self, outputs: str, opset=11, **kwargs): + """Export the model as onnx format files. + + In some cases, several files may be generated, + So please return a dict which contains the generated name with the file path. + + @param opset: The version of the ONNX operator set to use. + @param outputs: The output dir. + @param kwargs: In this default implementation, + you can pass the arguments needed by _torch_export_onnx, other unrecognized args + will be carried to generate_dummy_inputs as extra arguments (such as input shape). + @return: A dict containing the model key - model file path pairs. 
+ """ + model = self.model + if not isinstance(model, nn.Module) and hasattr(model, 'model'): + model = model.model + onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE) + self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs) + return {'model': onnx_file} + + def export_torch_script(self, outputs: str, **kwargs): + """Export the model as torch script files. + + In some cases, several files may be generated, + So please return a dict which contains the generated name with the file path. + + @param outputs: The output dir. + @param kwargs: In this default implementation, + you can pass the arguments needed by _torch_export_torch_script, other unrecognized args + will be carried to generate_dummy_inputs as extra arguments (like input shape). + @return: A dict contains the model name with the model file path. + """ + model = self.model + if not isinstance(model, nn.Module) and hasattr(model, 'model'): + model = model.model + ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE) + # generate ts by tracing + self._torch_export_torch_script(model, ts_file, **kwargs) + return {'model': ts_file} + + def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: + """Generate dummy inputs for model exportation to onnx or other formats by tracing. + @return: Dummy inputs. + """ + return None + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + """Return an ordered dict contains the model's input arguments name with their dynamic axis. + + About the information of dynamic axis please check the dynamic_axes argument of torch.onnx.export function + """ + return None + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + """Return an ordered dict contains the model's output arguments name with their dynamic axis. + + About the information of dynamic axis please check the dynamic_axes argument of torch.onnx.export function + """ + return None + + def _torch_export_onnx(self, + model: nn.Module, + output: str, + opset: int = 11, + device: str = 'cpu', + validation: bool = True, + rtol: float = None, + atol: float = None, + **kwargs): + """Export the model to an onnx format file. + + @param model: A torch.nn.Module instance to export. + @param output: The output file. + @param opset: The version of the ONNX operator set to use. + @param device: The device used to forward. + @param validation: Whether validate the export file. + @param rtol: The rtol used to regress the outputs. + @param atol: The atol used to regress the outputs. 
+ """ + + dummy_inputs = self.generate_dummy_inputs(**kwargs) + inputs = self.inputs + outputs = self.outputs + if dummy_inputs is None or inputs is None or outputs is None: + raise NotImplementedError( + 'Model property dummy_inputs,inputs,outputs must be set.') + + with torch.no_grad(): + model.eval() + device = torch.device(device) + model.to(device) + dummy_inputs = collate_fn(dummy_inputs, device) + + if isinstance(dummy_inputs, Mapping): + dummy_inputs = dict(dummy_inputs) + onnx_outputs = list(self.outputs.keys()) + + with replace_call(): + onnx_export( + model, + (dummy_inputs, ), + f=output, + input_names=list(inputs.keys()), + output_names=onnx_outputs, + dynamic_axes={ + name: axes + for name, axes in chain(inputs.items(), + outputs.items()) + }, + do_constant_folding=True, + opset_version=opset, + ) + + if validation: + try: + import onnx + import onnxruntime as ort + except ImportError: + logger.warn( + 'Cannot validate the exported onnx file, because ' + 'the installation of onnx or onnxruntime cannot be found') + return + onnx_model = onnx.load(output) + onnx.checker.check_model(onnx_model) + ort_session = ort.InferenceSession(output) + with torch.no_grad(): + model.eval() + outputs_origin = model.forward( + *_decide_input_format(model, dummy_inputs)) + if isinstance(outputs_origin, Mapping): + outputs_origin = torch_nested_numpify( + list(outputs_origin.values())) + outputs = ort_session.run( + onnx_outputs, + torch_nested_numpify(dummy_inputs), + ) + + tols = {} + if rtol is not None: + tols['rtol'] = rtol + if atol is not None: + tols['atol'] = atol + if not compare_arguments_nested('Onnx model output match failed', + outputs, outputs_origin, **tols): + raise RuntimeError( + 'export onnx failed because of validation error.') + + def _torch_export_torch_script(self, + model: nn.Module, + output: str, + device: str = 'cpu', + validation: bool = True, + rtol: float = None, + atol: float = None, + **kwargs): + """Export the model to a torch script file. + + @param model: A torch.nn.Module instance to export. + @param output: The output file. + @param device: The device used to forward. + @param validation: Whether validate the export file. + @param rtol: The rtol used to regress the outputs. + @param atol: The atol used to regress the outputs. + """ + + model.eval() + dummy_inputs = self.generate_dummy_inputs(**kwargs) + if dummy_inputs is None: + raise NotImplementedError( + 'Model property dummy_inputs must be set.') + dummy_inputs = collate_fn(dummy_inputs, device) + if isinstance(dummy_inputs, Mapping): + dummy_inputs = tuple(dummy_inputs.values()) + with torch.no_grad(): + model.eval() + with replace_call(): + traced_model = torch.jit.trace( + model, dummy_inputs, strict=False) + torch.jit.save(traced_model, output) + + if validation: + ts_model = torch.jit.load(output) + with torch.no_grad(): + model.eval() + ts_model.eval() + outputs = ts_model.forward(*dummy_inputs) + outputs = torch_nested_numpify(outputs) + outputs_origin = model.forward(*dummy_inputs) + outputs_origin = torch_nested_numpify(outputs_origin) + tols = {} + if rtol is not None: + tols['rtol'] = rtol + if atol is not None: + tols['atol'] = atol + if not compare_arguments_nested( + 'Torch script model output match failed', outputs, + outputs_origin, **tols): + raise RuntimeError( + 'export torch script failed because of validation error.') + + +@contextmanager +def replace_call(): + """This function is used to recover the original call method. + + The Model class of modelscope overrides the call method. 
diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index 5369220f..c5db2b57 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -28,7 +28,7 @@
 if is_torch_available():
     import torch
 if is_tf_available():
-    import tensorflow as tf
+    pass

 Tensor = Union['torch.Tensor', 'tf.Tensor']
 Input = Union[str, tuple, MsDataset, 'Image.Image', 'numpy.ndarray']
@@ -204,44 +204,7 @@ class Pipeline(ABC):
             yield self._process_single(ele, *args, **kwargs)

     def _collate_fn(self, data):
-        """Prepare the input just before the forward function.
-        This method will move the tensors to the right device.
-        Usually this method does not need to be overridden.
-
-        Args:
-            data: The data out of the dataloader.
-
-        Returns: The processed data.
-
-        """
-        from torch.utils.data.dataloader import default_collate
-        from modelscope.preprocessors import InputFeatures
-        if isinstance(data, dict) or isinstance(data, Mapping):
-            return type(data)(
-                {k: self._collate_fn(v)
-                 for k, v in data.items()})
-        elif isinstance(data, (tuple, list)):
-            if isinstance(data[0], (int, float)):
-                return default_collate(data).to(self.device)
-            else:
-                return type(data)(self._collate_fn(v) for v in data)
-        elif isinstance(data, np.ndarray):
-            if data.dtype.type is np.str_:
-                return data
-            else:
-                return self._collate_fn(torch.from_numpy(data))
-        elif isinstance(data, torch.Tensor):
-            return data.to(self.device)
-        elif isinstance(data, (bytes, str, int, float, bool, type(None))):
-            return data
-        elif isinstance(data, InputFeatures):
-            return data
-        else:
-            import mmcv
-            if isinstance(data, mmcv.parallel.data_container.DataContainer):
-                return data
-            else:
-                raise ValueError(f'Unsupported data type {type(data)}')
+        return collate_fn(data, self.device)

     def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
         preprocess_params = kwargs.get('preprocess_params', {})
@@ -410,3 +373,43 @@ class DistributedPipeline(Pipeline):
         @return: The forward results.
         """
         pass
+
+
+def collate_fn(data, device):
+    """Prepare the input just before the forward function.
+    This function will move the tensors to the right device.
+    Usually it does not need to be replaced.
+
+    Args:
+        data: The data out of the dataloader.
+        device: The device to move data to.
+
+    Returns: The processed data.
+
+    """
+    from torch.utils.data.dataloader import default_collate
+    from modelscope.preprocessors import InputFeatures
+    if isinstance(data, dict) or isinstance(data, Mapping):
+        return type(data)({k: collate_fn(v, device) for k, v in data.items()})
+    elif isinstance(data, (tuple, list)):
+        if isinstance(data[0], (int, float)):
+            return default_collate(data).to(device)
+        else:
+            return type(data)(collate_fn(v, device) for v in data)
+    elif isinstance(data, np.ndarray):
+        if data.dtype.type is np.str_:
+            return data
+        else:
+            return collate_fn(torch.from_numpy(data), device)
+    elif isinstance(data, torch.Tensor):
+        return data.to(device)
+    elif isinstance(data, (bytes, str, int, float, bool, type(None))):
+        return data
+    elif isinstance(data, InputFeatures):
+        return data
+    else:
+        import mmcv
+        if isinstance(data, mmcv.parallel.data_container.DataContainer):
+            return data
+        else:
+            raise ValueError(f'Unsupported data type {type(data)}')
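
An illustration of what the now-reusable collate_fn does with nested data. This is a sketch assuming a torch-only environment, so the mmcv fallback branch is never reached:

```python
# Hypothetical illustration of collate_fn moving nested structures to a device.
import numpy as np
import torch

from modelscope.pipelines.base import collate_fn

data = {
    'input_ids': torch.zeros(2, 8, dtype=torch.long),
    'meta': {'scores': np.array([0.1, 0.2]), 'name': 'sample'},
    'lengths': [3, 5],
}
out = collate_fn(data, torch.device('cpu'))
# Tensors, numeric ndarrays and numeric lists become tensors on the target
# device; strings pass through untouched.
print(type(out['meta']['scores']), out['meta']['name'], out['lengths'])
```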
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 57d38da7..d6b0da40 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -246,6 +246,7 @@ class ModelFile(object):
     ONNX_MODEL_FILE = 'model.onnx'
     LABEL_MAPPING = 'label_mapping.json'
     TRAIN_OUTPUT_DIR = 'output'
+    TS_MODEL_FILE = 'model.ts'


 class ConfigFields(object):
diff --git a/modelscope/utils/regress_test_utils.py b/modelscope/utils/regress_test_utils.py
index 8b6c24a7..47bbadfe 100644
--- a/modelscope/utils/regress_test_utils.py
+++ b/modelscope/utils/regress_test_utils.py
@@ -352,10 +352,10 @@ def numpify_tensor_nested(tensors, reduction=None, clip_value=10000):
         return type(tensors)(
             numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
     if isinstance(tensors, Mapping):
-        return type(tensors)({
+        return {
             k: numpify_tensor_nested(t, reduction, clip_value)
             for k, t in tensors.items()
-        })
+        }
     if isinstance(tensors, torch.Tensor):
         t: np.ndarray = tensors.cpu().numpy()
         if clip_value is not None:
@@ -375,9 +375,7 @@ def detach_tensor_nested(tensors):
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(detach_tensor_nested(t) for t in tensors)
     if isinstance(tensors, Mapping):
-        return type(tensors)(
-            {k: detach_tensor_nested(t)
-             for k, t in tensors.items()})
+        return {k: detach_tensor_nested(t) for k, t in tensors.items()}
     if isinstance(tensors, torch.Tensor):
         return tensors.detach()
     return tensors
@@ -496,7 +494,11 @@ def intercept_module(module: nn.Module, io_json, full_name, restore):
         intercept_module(module, io_json, full_name, restore)


-def compare_arguments_nested(print_content, arg1, arg2):
+def compare_arguments_nested(print_content,
+                             arg1,
+                             arg2,
+                             rtol=1.e-3,
+                             atol=1.e-8):
     type1 = type(arg1)
     type2 = type(arg2)
     if type1.__name__ != type2.__name__:
@@ -515,7 +517,7 @@ def compare_arguments_nested(print_content,
             return False
         return True
     elif isinstance(arg1, (float, np.floating)):
-        if not np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8, equal_nan=True):
+        if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
             if print_content is not None:
                 print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
             return False
@@ -562,7 +564,7 @@ def compare_arguments_nested(print_content,
         arg2 = np.where(np.equal(arg2, None), np.NaN,
                         arg2).astype(dtype=np.float)
         if not all(
-                np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8,
+                np.isclose(arg1, arg2, rtol=rtol, atol=atol,
                            equal_nan=True).flatten()):
             if print_content is not None:
                 print(f'{print_content}')
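
The new rtol/atol parameters thread straight through to np.isclose. A small sketch of loosening the tolerance when regressing exported-model outputs, assuming plain float ndarrays:

```python
# Sketch of the new tolerance knobs on compare_arguments_nested.
import numpy as np

from modelscope.utils.regress_test_utils import compare_arguments_nested

a = np.array([1.00, 2.00])
b = np.array([1.01, 2.01])
print(compare_arguments_nested(None, a, b))  # False with the tight defaults
print(compare_arguments_nested(None, a, b, atol=5e-2))  # True once loosened
```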
b438e476..b68a639c 100644
--- a/modelscope/utils/tensor_utils.py
+++ b/modelscope/utils/tensor_utils.py
@@ -1,12 +1,24 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/transformers.
+from collections.abc import Mapping


 def torch_nested_numpify(tensors):
+    """Numpify nested torch tensors.
+
+    NOTE: If the input is dict-like (Mapping, dict, OrderedDict, etc.), the return type will be dict.
+
+    @param tensors: Nested torch tensors.
+    @return: The numpified tensors.
+    """
+    import torch
     "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(torch_nested_numpify(t) for t in tensors)
+    if isinstance(tensors, Mapping):
+        # return a plain dict
+        return {k: torch_nested_numpify(t) for k, t in tensors.items()}
     if isinstance(tensors, torch.Tensor):
         t = tensors.cpu()
         return t.numpy()
@@ -14,10 +26,20 @@ def torch_nested_numpify(tensors):


 def torch_nested_detach(tensors):
+    """Detach nested torch tensors.
+
+    NOTE: If the input is dict-like (Mapping, dict, OrderedDict, etc.), the return type will be dict.
+
+    @param tensors: Nested torch tensors.
+    @return: The detached tensors.
+    """
+    import torch
     "Detach `tensors` (even if it's a nested list/tuple of tensors)."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(torch_nested_detach(t) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return {k: torch_nested_detach(t) for k, t in tensors.items()}
     if isinstance(tensors, torch.Tensor):
         return tensors.detach()
     return tensors
diff --git a/tests/export/__init__.py b/tests/export/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/export/test_export_sbert_sequence_classification.py b/tests/export/test_export_sbert_sequence_classification.py
new file mode 100644
index 00000000..535b3f5d
--- /dev/null
+++ b/tests/export/test_export_sbert_sequence_classification.py
@@ -0,0 +1,37 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+from modelscope.exporters import Exporter, TorchModelExporter
+from modelscope.models.base import Model
+from modelscope.utils.test_utils import test_level
+
+
+class TestExportSbertSequenceClassification(unittest.TestCase):
+
+    def setUp(self):
+        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+        self.model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
+
+    def tearDown(self):
+        shutil.rmtree(self.tmp_dir)
+        super().tearDown()
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_export_sbert_sequence_classification(self):
+        model = Model.from_pretrained(self.model_id)
+        print(
+            Exporter.from_model(model).export_onnx(
+                shape=(2, 256), outputs=self.tmp_dir))
+        print(
+            TorchModelExporter.from_model(model).export_torch_script(
+                shape=(2, 256), outputs=self.tmp_dir))
+
+
+if __name__ == '__main__':
+    unittest.main()
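
As a follow-up to the test above, a hedged sketch of double-checking an exported ONNX file with onnxruntime. The directory is a placeholder, the file name comes from ModelFile.ONNX_MODEL_FILE, and the input/output names come from the sbert exporter's inputs/outputs properties:

```python
# Hypothetical post-export check (not part of this patch); assumes the file
# produced by export_onnx(shape=(2, 256), outputs='/tmp/export').
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('/tmp/export/model.onnx')
feed = {
    'input_ids': np.zeros((2, 256), dtype=np.int64),
    'attention_mask': np.ones((2, 256), dtype=np.int64),
    'token_type_ids': np.zeros((2, 256), dtype=np.int64),
}
# 'logits' is the output name declared by SbertForSequenceClassificationExporter.
logits = session.run(['logits'], feed)[0]
print(logits.shape)  # (2, num_labels)
```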