1. Add exporter module.
2. Move collate_fn out of the base pipeline class so it can be reused.
3. Add a dummy-inputs method to the NLP tokenization preprocessor base class.
4. Support Mapping in the tensor numpify and detach utilities.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10037704
master
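For reviewers, a minimal end-to-end sketch of the API this PR adds, based on the unit test at the bottom of the diff (the model id mirrors the test; the output dir is just an example):

```python
# Minimal usage sketch of the new exporter module (mirrors the unit test below).
import os

from modelscope.exporters import Exporter
from modelscope.models import Model

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')

out_dir = '/tmp/export_out'  # example path
os.makedirs(out_dir, exist_ok=True)

# Returns a dict mapping a name to the generated file path,
# e.g. {'model': '/tmp/export_out/model.onnx'}.
print(Exporter.from_model(model).export_onnx(shape=(2, 256), outputs=out_dir))
```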
@@ -0,0 +1,4 @@
from .base import Exporter
from .builder import build_exporter
from .nlp import SbertForSequenceClassificationExporter
from .torch_model_exporter import TorchModelExporter
@@ -0,0 +1,53 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from abc import ABC, abstractmethod

from modelscope.models import Model
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from .builder import build_exporter


class Exporter(ABC):
    """Exporter base class for outputting a model to onnx, torch script, graphdef, etc.
    """

    def __init__(self):
        self.model = None

    @classmethod
    def from_model(cls, model: Model, **kwargs):
        """Build the Exporter instance.

        @param model: A model instance. It will be used to output the generated file,
            and the configuration.json in its model_dir field will be used to create the exporter instance.
        @param kwargs: Extra kwargs used to create the Exporter instance.
        @return: The Exporter instance.
        """
        cfg = Config.from_file(
            os.path.join(model.model_dir, ModelFile.CONFIGURATION))
        task_name = cfg.task
        model_cfg = cfg.model
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        export_cfg = ConfigDict({'type': model_cfg.type})
        if hasattr(cfg, 'export'):
            export_cfg.update(cfg.export)
        exporter = build_exporter(export_cfg, task_name, kwargs)
        exporter.model = model
        return exporter

    @abstractmethod
    def export_onnx(self, outputs: str, opset=11, **kwargs):
        """Export the model as onnx format files.

        In some cases several files may be generated,
        so please return a dict which maps each generated file name to its path.

        @param opset: The version of the ONNX operator set to use.
        @param outputs: The output dir.
        @param kwargs: In this default implementation,
            kwargs will be carried to generate_dummy_inputs as extra arguments (like input shape).
        @return: A dict mapping the model name to the model file path.
        """
        pass
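Since Exporter is an ABC, a concrete exporter only needs to implement export_onnx and receives its model through from_model. A minimal hypothetical subclass to illustrate the contract (MyDummyExporter and its body are illustrative only):

```python
# Hypothetical Exporter subclass illustrating the contract:
# export_onnx writes the file(s) into `outputs` and returns {name: path}.
import os

from modelscope.exporters.base import Exporter


class MyDummyExporter(Exporter):

    def export_onnx(self, outputs: str, opset=11, **kwargs):
        onnx_file = os.path.join(outputs, 'model.onnx')
        # A real exporter would trace self.model and write onnx_file here.
        return {'model': onnx_file}
```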
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.config import ConfigDict
from modelscope.utils.registry import Registry, build_from_cfg

EXPORTERS = Registry('exporters')


def build_exporter(cfg: ConfigDict,
                   task_name: str = None,
                   default_args: dict = None):
    """Build an exporter from the given config dict.

    Args:
        cfg (:obj:`ConfigDict`): config dict for the exporter object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details.
        default_args (dict, optional): Default initialization arguments.
    """
    return build_from_cfg(
        cfg, EXPORTERS, group_key=task_name, default_args=default_args)
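A quick sketch of the register/build round trip through the new registry (the task and module names are made up):

```python
# Registering an exporter and building it back from a config dict.
from modelscope.exporters.builder import EXPORTERS, build_exporter
from modelscope.utils.config import ConfigDict


@EXPORTERS.register_module('my-task', module_name='my-model')
class MyExporter:
    pass


exporter = build_exporter(
    ConfigDict({'type': 'my-model'}), task_name='my-task')
assert isinstance(exporter, MyExporter)
```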
@@ -0,0 +1,2 @@
from .sbert_for_sequence_classification_exporter import \
    SbertForSequenceClassificationExporter
@@ -0,0 +1,81 @@
import os
from collections import OrderedDict
from typing import Any, Dict, Mapping, Tuple

from torch.utils.data.dataloader import default_collate

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import Preprocessor, build_preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import ModeKeys, Tasks


@EXPORTERS.register_module(
    Tasks.sentence_similarity, module_name=Models.structbert)
@EXPORTERS.register_module(
    Tasks.sentiment_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert)
@EXPORTERS.register_module(
    Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForSequenceClassificationExporter(TorchModelExporter):

    def generate_dummy_inputs(self,
                              shape: Tuple = None,
                              **kwargs) -> Dict[str, Any]:
        """Generate dummy inputs for exporting the model to onnx or other formats by tracing.

        @param shape: A tuple describing the input shape, with at most two dimensions.
            shape = (1, ): batch_size=1, sequence_length is taken from the preprocessor config.
            shape = (8, 128): batch_size=8, sequence_length=128, which overrides the preprocessor config.
        @return: Dummy inputs.
        """
        cfg = Config.from_file(
            os.path.join(self.model.model_dir, 'configuration.json'))
        field_name = Tasks.find_field_by_task(cfg.task)
        if 'type' not in cfg.preprocessor and 'val' in cfg.preprocessor:
            cfg = cfg.preprocessor.val
        else:
            cfg = cfg.preprocessor

        batch_size = 1
        sequence_length = {}
        if shape is not None:
            if len(shape) == 1:
                batch_size = shape[0]
            elif len(shape) == 2:
                batch_size, max_length = shape
                sequence_length = {'sequence_length': max_length}
        cfg.update({
            'model_dir': self.model.model_dir,
            'mode': ModeKeys.TRAIN,
            **sequence_length
        })
        preprocessor: Preprocessor = build_preprocessor(cfg, field_name)
        if preprocessor.pair:
            first_sequence = preprocessor.tokenizer.unk_token
            second_sequence = preprocessor.tokenizer.unk_token
        else:
            first_sequence = preprocessor.tokenizer.unk_token
            second_sequence = None
        batched = []
        for _ in range(batch_size):
            batched.append(preprocessor((first_sequence, second_sequence)))
        return default_collate(batched)

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        dynamic_axis = {0: 'batch', 1: 'sequence'}
        return OrderedDict([
            ('input_ids', dynamic_axis),
            ('attention_mask', dynamic_axis),
            ('token_type_ids', dynamic_axis),
        ])

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict({'logits': {0: 'batch'}})
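To make the shape semantics concrete: assuming the preprocessor pads to the configured sequence_length, shape=(2, 128) should yield a batch of three int tensors whose keys line up with the inputs property above; that alignment is what lets torch.onnx.export bind input_names and dynamic_axes consistently. A hedged sketch:

```python
# Sketch: checking the dummy batch produced with shape=(2, 128).
# Assumes the tokenizer pads every sample to the configured max length.
from modelscope.exporters import SbertForSequenceClassificationExporter
from modelscope.models import Model

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')
exporter = SbertForSequenceClassificationExporter.from_model(model)
dummy = exporter.generate_dummy_inputs(shape=(2, 128))
for name in ('input_ids', 'attention_mask', 'token_type_ids'):
    assert tuple(dummy[name].shape) == (2, 128)  # (batch, sequence)
```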
@@ -0,0 +1,247 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from contextlib import contextmanager
from itertools import chain
from typing import Any, Dict, Mapping

import torch
from torch import nn
from torch.onnx import export as onnx_export
from torch.onnx.utils import _decide_input_format

from modelscope.models import TorchModel
from modelscope.pipelines.base import collate_fn
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.regress_test_utils import compare_arguments_nested
from modelscope.utils.tensor_utils import torch_nested_numpify
from .base import Exporter

logger = get_logger(__name__)


class TorchModelExporter(Exporter):
    """The torch base class of exporters.

    This class provides default implementations for exporting to onnx and torch script.
    A specific model may implement its own exporter by overriding export_onnx/export_torch_script
    and providing implementations for the generate_dummy_inputs/inputs/outputs methods.
    """

    def export_onnx(self, outputs: str, opset=11, **kwargs):
        """Export the model as onnx format files.

        In some cases several files may be generated,
        so please return a dict which maps each generated file name to its path.

        @param opset: The version of the ONNX operator set to use.
        @param outputs: The output dir.
        @param kwargs: In this default implementation,
            you can pass the arguments needed by _torch_export_onnx; any unrecognized args
            will be carried to generate_dummy_inputs as extra arguments (such as input shape).
        @return: A dict containing model-name to model-file-path pairs.
        """
        model = self.model
        if not isinstance(model, nn.Module) and hasattr(model, 'model'):
            model = model.model
        onnx_file = os.path.join(outputs, ModelFile.ONNX_MODEL_FILE)
        self._torch_export_onnx(model, onnx_file, opset=opset, **kwargs)
        return {'model': onnx_file}

    def export_torch_script(self, outputs: str, **kwargs):
        """Export the model as torch script files.

        In some cases several files may be generated,
        so please return a dict which maps each generated file name to its path.

        @param outputs: The output dir.
        @param kwargs: In this default implementation,
            you can pass the arguments needed by _torch_export_torch_script; any unrecognized args
            will be carried to generate_dummy_inputs as extra arguments (like input shape).
        @return: A dict containing model-name to model-file-path pairs.
        """
        model = self.model
        if not isinstance(model, nn.Module) and hasattr(model, 'model'):
            model = model.model
        ts_file = os.path.join(outputs, ModelFile.TS_MODEL_FILE)
        # generate the torch script file by tracing
        self._torch_export_torch_script(model, ts_file, **kwargs)
        return {'model': ts_file}

    def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
        """Generate dummy inputs for exporting the model to onnx or other formats by tracing.

        @return: Dummy inputs.
        """
        return None

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return an ordered dict mapping the model's input argument names to their dynamic axes.

        For details about dynamic axes, check the dynamic_axes argument of the torch.onnx.export function.
        """
        return None

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Return an ordered dict mapping the model's output argument names to their dynamic axes.

        For details about dynamic axes, check the dynamic_axes argument of the torch.onnx.export function.
        """
        return None

    def _torch_export_onnx(self,
                           model: nn.Module,
                           output: str,
                           opset: int = 11,
                           device: str = 'cpu',
                           validation: bool = True,
                           rtol: float = None,
                           atol: float = None,
                           **kwargs):
        """Export the model to an onnx format file.

        @param model: A torch.nn.Module instance to export.
        @param output: The output file.
        @param opset: The version of the ONNX operator set to use.
        @param device: The device used to forward.
        @param validation: Whether to validate the exported file.
        @param rtol: The relative tolerance used to regress the outputs.
        @param atol: The absolute tolerance used to regress the outputs.
        """
        dummy_inputs = self.generate_dummy_inputs(**kwargs)
        inputs = self.inputs
        outputs = self.outputs
        if dummy_inputs is None or inputs is None or outputs is None:
            raise NotImplementedError(
                'Model properties dummy_inputs, inputs and outputs must be set.')

        with torch.no_grad():
            model.eval()
            device = torch.device(device)
            model.to(device)
            dummy_inputs = collate_fn(dummy_inputs, device)
            if isinstance(dummy_inputs, Mapping):
                dummy_inputs = dict(dummy_inputs)
            onnx_outputs = list(self.outputs.keys())

            with replace_call():
                onnx_export(
                    model,
                    (dummy_inputs, ),
                    f=output,
                    input_names=list(inputs.keys()),
                    output_names=onnx_outputs,
                    dynamic_axes={
                        name: axes
                        for name, axes in chain(inputs.items(),
                                                outputs.items())
                    },
                    do_constant_folding=True,
                    opset_version=opset,
                )

        if validation:
            try:
                import onnx
                import onnxruntime as ort
            except ImportError:
                logger.warn(
                    'Cannot validate the exported onnx file, because '
                    'the installation of onnx or onnxruntime cannot be found')
                return
            onnx_model = onnx.load(output)
            onnx.checker.check_model(onnx_model)
            ort_session = ort.InferenceSession(output)
            with torch.no_grad():
                model.eval()
                outputs_origin = model.forward(
                    *_decide_input_format(model, dummy_inputs))
            if isinstance(outputs_origin, Mapping):
                outputs_origin = torch_nested_numpify(
                    list(outputs_origin.values()))
            outputs = ort_session.run(
                onnx_outputs,
                torch_nested_numpify(dummy_inputs),
            )

            tols = {}
            if rtol is not None:
                tols['rtol'] = rtol
            if atol is not None:
                tols['atol'] = atol
            if not compare_arguments_nested('Onnx model output match failed',
                                            outputs, outputs_origin, **tols):
                raise RuntimeError(
                    'export onnx failed because of validation error.')

    def _torch_export_torch_script(self,
                                   model: nn.Module,
                                   output: str,
                                   device: str = 'cpu',
                                   validation: bool = True,
                                   rtol: float = None,
                                   atol: float = None,
                                   **kwargs):
        """Export the model to a torch script file.

        @param model: A torch.nn.Module instance to export.
        @param output: The output file.
        @param device: The device used to forward.
        @param validation: Whether to validate the exported file.
        @param rtol: The relative tolerance used to regress the outputs.
        @param atol: The absolute tolerance used to regress the outputs.
        """
        model.eval()
        dummy_inputs = self.generate_dummy_inputs(**kwargs)
        if dummy_inputs is None:
            raise NotImplementedError(
                'Model property dummy_inputs must be set.')
        dummy_inputs = collate_fn(dummy_inputs, device)
        if isinstance(dummy_inputs, Mapping):
            dummy_inputs = tuple(dummy_inputs.values())
        with torch.no_grad():
            model.eval()
            with replace_call():
                traced_model = torch.jit.trace(
                    model, dummy_inputs, strict=False)
        torch.jit.save(traced_model, output)

        if validation:
            ts_model = torch.jit.load(output)
            with torch.no_grad():
                model.eval()
                ts_model.eval()
                outputs = ts_model.forward(*dummy_inputs)
                outputs = torch_nested_numpify(outputs)
                outputs_origin = model.forward(*dummy_inputs)
                outputs_origin = torch_nested_numpify(outputs_origin)
            tols = {}
            if rtol is not None:
                tols['rtol'] = rtol
            if atol is not None:
                tols['atol'] = atol
            if not compare_arguments_nested(
                    'Torch script model output match failed', outputs,
                    outputs_origin, **tols):
                raise RuntimeError(
                    'export torch script failed because of validation error.')


@contextmanager
def replace_call():
    """Temporarily restore the original __call__ method of TorchModel.

    The Model class of modelscope overrides the __call__ method. When exporting to onnx or
    torch script, torch prepares the parameters according to the prototype of the forward
    method but traces the __call__ method, which causes problems. Here we restore __call__
    to the default torch.nn.Module implementation, and change it back after the tracing is done.
    """
    TorchModel.call_origin, TorchModel.__call__ = TorchModel.__call__, TorchModel._call_impl
    yield
    TorchModel.__call__ = TorchModel.call_origin
    del TorchModel.call_origin
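The validation branch compares onnxruntime / torch script outputs against the eager forward with compare_arguments_nested, so callers can relax the tolerances per export. A sketch (output dir and tolerance values are examples):

```python
# Sketch: tuning validation tolerances. rtol/atol reach compare_arguments_nested
# through _torch_export_onnx, while `shape` is unrecognized there and therefore
# falls through to generate_dummy_inputs.
import os

from modelscope.exporters import Exporter
from modelscope.models import Model

model = Model.from_pretrained(
    'damo/nlp_structbert_sentence-similarity_chinese-base')
out_dir = '/tmp/export_out'
os.makedirs(out_dir, exist_ok=True)
Exporter.from_model(model).export_onnx(
    outputs=out_dir, opset=11, rtol=1e-2, atol=1e-5, shape=(2, 256))
```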
@@ -28,7 +28,7 @@ if is_torch_available():
    import torch

if is_tf_available():
    import tensorflow as tf
    pass

Tensor = Union['torch.Tensor', 'tf.Tensor']
Input = Union[str, tuple, MsDataset, 'Image.Image', 'numpy.ndarray']
@@ -204,44 +204,7 @@ class Pipeline(ABC):
                yield self._process_single(ele, *args, **kwargs)

    def _collate_fn(self, data):
-        """Prepare the input just before the forward function.
-        This method will move the tensors to the right device.
-        Usually this method does not need to be overridden.
-
-        Args:
-            data: The data out of the dataloader.
-
-        Returns: The processed data.
-        """
-        from torch.utils.data.dataloader import default_collate
-        from modelscope.preprocessors import InputFeatures
-        if isinstance(data, dict) or isinstance(data, Mapping):
-            return type(data)(
-                {k: self._collate_fn(v)
-                 for k, v in data.items()})
-        elif isinstance(data, (tuple, list)):
-            if isinstance(data[0], (int, float)):
-                return default_collate(data).to(self.device)
-            else:
-                return type(data)(self._collate_fn(v) for v in data)
-        elif isinstance(data, np.ndarray):
-            if data.dtype.type is np.str_:
-                return data
-            else:
-                return self._collate_fn(torch.from_numpy(data))
-        elif isinstance(data, torch.Tensor):
-            return data.to(self.device)
-        elif isinstance(data, (bytes, str, int, float, bool, type(None))):
-            return data
-        elif isinstance(data, InputFeatures):
-            return data
-        else:
-            import mmcv
-            if isinstance(data, mmcv.parallel.data_container.DataContainer):
-                return data
-            else:
-                raise ValueError(f'Unsupported data type {type(data)}')
+        return collate_fn(data, self.device)

    def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
        preprocess_params = kwargs.get('preprocess_params', {})
@@ -410,3 +373,43 @@ class DistributedPipeline(Pipeline):
        @return: The forward results.
        """
        pass


def collate_fn(data, device):
    """Prepare the input just before the forward function.

    This function will move the tensors to the right device.
    Usually it does not need to be replaced.

    Args:
        data: The data out of the dataloader.
        device: The device to move data to.

    Returns: The processed data.
    """
    from torch.utils.data.dataloader import default_collate
    from modelscope.preprocessors import InputFeatures
    if isinstance(data, dict) or isinstance(data, Mapping):
        return type(data)({k: collate_fn(v, device) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        if isinstance(data[0], (int, float)):
            return default_collate(data).to(device)
        else:
            return type(data)(collate_fn(v, device) for v in data)
    elif isinstance(data, np.ndarray):
        if data.dtype.type is np.str_:
            return data
        else:
            return collate_fn(torch.from_numpy(data), device)
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, (bytes, str, int, float, bool, type(None))):
        return data
    elif isinstance(data, InputFeatures):
        return data
    else:
        import mmcv
        if isinstance(data, mmcv.parallel.data_container.DataContainer):
            return data
        else:
            raise ValueError(f'Unsupported data type {type(data)}')
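A small sketch of what the relocated collate_fn does with nested data, now that it is reusable outside the pipeline (TorchModelExporter above is the first such caller):

```python
# collate_fn recursively moves tensors inside nested containers to a device,
# converting numpy arrays on the way and passing strings through untouched.
import numpy as np
import torch

from modelscope.pipelines.base import collate_fn

batch = {
    'input_ids': torch.zeros(2, 16, dtype=torch.long),
    'lengths': np.array([16, 12]),
    'label': 'positive',  # str hits the pass-through branch
}
moved = collate_fn(batch, torch.device('cpu'))
assert moved['input_ids'].device.type == 'cpu'
assert isinstance(moved['lengths'], torch.Tensor)  # numpy -> tensor
assert moved['label'] == 'positive'
```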
@@ -246,6 +246,7 @@ class ModelFile(object):
    ONNX_MODEL_FILE = 'model.onnx'
    LABEL_MAPPING = 'label_mapping.json'
    TRAIN_OUTPUT_DIR = 'output'
    TS_MODEL_FILE = 'model.ts'


class ConfigFields(object):
@@ -352,10 +352,10 @@ def numpify_tensor_nested(tensors, reduction=None, clip_value=10000):
        return type(tensors)(
            numpify_tensor_nested(t, reduction, clip_value) for t in tensors)
    if isinstance(tensors, Mapping):
-        return type(tensors)({
+        return {
            k: numpify_tensor_nested(t, reduction, clip_value)
            for k, t in tensors.items()
-        })
+        }
    if isinstance(tensors, torch.Tensor):
        t: np.ndarray = tensors.cpu().numpy()
        if clip_value is not None:
@@ -375,9 +375,7 @@ def detach_tensor_nested(tensors):
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(detach_tensor_nested(t) for t in tensors)
    if isinstance(tensors, Mapping):
-        return type(tensors)(
-            {k: detach_tensor_nested(t)
-             for k, t in tensors.items()})
+        return {k: detach_tensor_nested(t) for k, t in tensors.items()}
    if isinstance(tensors, torch.Tensor):
        return tensors.detach()
    return tensors
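Worth noting for reviewers: after this change, Mapping inputs come back as a plain dict instead of reconstructing type(tensors), since not every Mapping subclass can be rebuilt from a single dict argument. A quick sketch:

```python
# numpify_tensor_nested now returns a plain dict for any Mapping input.
from collections import OrderedDict

import torch

from modelscope.utils.regress_test_utils import numpify_tensor_nested

out = numpify_tensor_nested(OrderedDict(a=torch.ones(2), b=torch.zeros(2)))
assert isinstance(out, dict) and not isinstance(out, OrderedDict)
```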
@@ -496,7 +494,11 @@ def intercept_module(module: nn.Module,
            intercept_module(module, io_json, full_name, restore)


-def compare_arguments_nested(print_content, arg1, arg2):
+def compare_arguments_nested(print_content,
+                             arg1,
+                             arg2,
+                             rtol=1.e-3,
+                             atol=1.e-8):
    type1 = type(arg1)
    type2 = type(arg2)
    if type1.__name__ != type2.__name__:
@@ -515,7 +517,7 @@ def compare_arguments_nested(print_content, arg1, arg2):
            return False
        return True
    elif isinstance(arg1, (float, np.floating)):
-        if not np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8, equal_nan=True):
+        if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
@@ -562,7 +564,7 @@ def compare_arguments_nested(print_content, arg1, arg2):
            arg2 = np.where(np.equal(arg2, None), np.NaN,
                            arg2).astype(dtype=np.float)
        if not all(
-                np.isclose(arg1, arg2, rtol=1.e-3, atol=1.e-8,
+                np.isclose(arg1, arg2, rtol=rtol, atol=atol,
                           equal_nan=True).flatten()):
            if print_content is not None:
                print(f'{print_content}')
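Usage sketch for the new tolerance parameters (the defaults keep the old hard-coded behaviour):

```python
# compare_arguments_nested with configurable rtol/atol.
import numpy as np

from modelscope.utils.regress_test_utils import compare_arguments_nested

a = np.array([1.000, 2.000])
b = np.array([1.004, 2.004])
assert not compare_arguments_nested(None, a, b)         # old defaults: too strict
assert compare_arguments_nested(None, a, b, rtol=1e-2)  # relaxed tolerance
```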
@@ -1,12 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
+from collections import Mapping


def torch_nested_numpify(tensors):
+    """Numpify nested torch tensors.
+
+    NOTE: If the input is dict-like (Mapping, dict, OrderedDict, etc.), the return type will be dict.
+
+    @param tensors: Nested torch tensors.
+    @return: The numpified tensors.
+    """
    import torch
-    "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(torch_nested_numpify(t) for t in tensors)
+    if isinstance(tensors, Mapping):
+        # return a plain dict
+        return {k: torch_nested_numpify(t) for k, t in tensors.items()}
    if isinstance(tensors, torch.Tensor):
        t = tensors.cpu()
        return t.numpy()
@@ -14,10 +26,20 @@ def torch_nested_numpify(tensors):


def torch_nested_detach(tensors):
+    """Detach nested torch tensors.
+
+    NOTE: If the input is dict-like (Mapping, dict, OrderedDict, etc.), the return type will be dict.
+
+    @param tensors: Nested torch tensors.
+    @return: The detached tensors.
+    """
    import torch
-    "Detach `tensors` (even if it's a nested list/tuple of tensors)."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(torch_nested_detach(t) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return {k: torch_nested_detach(t) for k, t in tensors.items()}
    if isinstance(tensors, torch.Tensor):
        return tensors.detach()
    return tensors
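A quick sketch of the new Mapping handling and the dict-out behaviour called out in the docstrings; this is what lets TorchModelExporter feed dummy inputs straight to onnxruntime, which expects {input_name: numpy_array}:

```python
# torch_nested_numpify descends into Mapping values and returns a plain dict.
from collections import OrderedDict

import numpy as np
import torch

from modelscope.utils.tensor_utils import torch_nested_numpify

out = torch_nested_numpify(OrderedDict(x=torch.arange(3)))
assert isinstance(out, dict)
assert isinstance(out['x'], np.ndarray)
```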
@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

from modelscope.exporters import Exporter, TorchModelExporter
from modelscope.models.base import Model
from modelscope.utils.test_utils import test_level


class TestExportSbertSequenceClassification(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
        self.model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_export_sbert_sequence_classification(self):
        model = Model.from_pretrained(self.model_id)
        print(
            Exporter.from_model(model).export_onnx(
                shape=(2, 256), outputs=self.tmp_dir))
        print(
            TorchModelExporter.from_model(model).export_torch_script(
                shape=(2, 256), outputs=self.tmp_dir))


if __name__ == '__main__':
    unittest.main()