| @@ -219,3 +219,17 @@ def reset_init_or_construct(template, variable_slot, new_data, scope): | |||
| template[variable_slot][scope].clear() | |||
| template[variable_slot][scope] += new_data | |||
| return template | |||
def replace_string_in_list(str_list: list, original_str: str, target_str: str):
    """
    Replace every occurrence of a substring within each element of a list.

    Args:
        str_list (list): List of strings to be processed.
        original_str (str): Substring to search for.
        target_str (str): Replacement substring.

    Returns:
        list, a new list where each element has the substring replaced.
    """
    return [element.replace(original_str, target_str) for element in str_list]
| @@ -27,9 +27,10 @@ from mindinsight.mindconverter.graph_based_converter.common.global_context impor | |||
| from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseOutput, ModuleOutputManager | |||
| from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config | |||
| from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr | |||
| from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, \ | |||
| get_imported_module | |||
| from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ | |||
| FIRST_LEVEL_INDENT, get_imported_module | |||
| from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list | |||
| class CodeStruct: | |||
| @@ -373,6 +374,10 @@ class Generator: | |||
| self.module_structs.get('[]').allocate_construct_header_x() | |||
| self.module_structs.get('[]').collect_returns() | |||
| for nd_struct in self.node_structs.values(): | |||
| if nd_struct.fragment.metadata.get("operation") == "Split": | |||
| self._split_op_procs(nd_struct) | |||
| def _update_all_modules_args_translator(self): | |||
| """Update all modules' args translators.""" | |||
| done_submodule = set() | |||
| @@ -633,3 +638,73 @@ class Generator: | |||
| if user in root_module.external_successor_nodes_names: | |||
| return True | |||
| return False | |||
def _split_op_procs(self, split_struct: NodeStruct):
    """
    Support for Split operation multiple outputs.

    For each successor of the Split node, find which Split output edge the
    successor consumes, then rewrite its placeholder input (any matched
    input containing "split") with the real output name.

    Args:
        split_struct (NodeStruct): The NodeStruct of the Split op.

    Raises:
        ValueError: If a successor consumes no Split output, or the
            submodule containing an external user cannot be located.
    """
    for successor in split_struct.successor_nodes_structs:
        idx = self._get_correct_input_idx_from_split(split_struct, successor)
        if idx is None:
            # Grammar fixed from original "should not has empty output".
            raise ValueError("The Split OP should not have empty output.")
        correct_input = split_struct.fragment.fragment.get_outputs_by_idx(0, idx)
        if split_struct.check_target_node_internal(successor.identifier):
            # 1. target user is internal: patch the successor node directly.
            target_struct = successor
        else:
            # 2. target user is external: patch the submodule, under the
            # closest common parent module, that contains the user node.
            public_parent = self._get_public_parent_module(split_struct, successor)
            target_struct = self._get_submodule_has_out_user_under_public_parent(public_parent, successor)
            if target_struct is None:
                raise ValueError("Unable to locate the submodule to be modified for Split output matching.")
        # Replace the first placeholder input still referring to "split".
        to_be_replaced = next((inp for inp in target_struct.matched_inputs if "split" in inp), None)
        if to_be_replaced is not None:
            target_struct.matched_inputs = replace_string_in_list(target_struct.matched_inputs,
                                                                  to_be_replaced,
                                                                  correct_input)
def _get_correct_input_idx_from_split(self, split_struct: NodeStruct, split_out_user: NodeStruct):
    """Return the index of the Split output edge the user consumes, or None if no edge matches."""
    user_inputs = split_out_user.fragment.metadata.get("inputs")
    for position, output_name in enumerate(split_struct.fragment.metadata.get("outputs")):
        if output_name in user_inputs:
            return position
    return None
def _get_public_parent_module(self, node_a: NodeStruct, node_b: NodeStruct):
    """Walk up from Node A's parents until a module that also contains Node B is found; return it."""
    target_onnx_name = node_b.onnx_name
    current = node_a
    while True:
        current = current.parent_module_struct
        if target_onnx_name in current.onnx_names:
            return current
def _get_submodule_has_out_user_under_public_parent(self, public_module: ModuleStruct, node_out_user: NodeStruct):
    """Return the submodule of *public_module* that contains the provided node, or None if absent."""
    target_onnx_name = node_out_user.onnx_name
    return next(
        (sub_module for sub_module in public_module.module_structs
         if target_onnx_name in sub_module.onnx_names),
        None)
| @@ -58,11 +58,6 @@ class Mapper(metaclass=abc.ABCMeta): | |||
| def _convert_trained_weights(**kwargs): | |||
| """Convert third party operation's weights into MindSpore operation.""" | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _convert_settings(**kwargs): | |||
| """Convert third party operation's params into MindSpore OP operator.""" | |||
| @classmethod | |||
| @abc.abstractmethod | |||
| def convert(cls, op_name: str, params: Dict, weights: Dict = None): | |||
| @@ -148,10 +143,6 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| def _convert_trained_weights(**kwargs): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| op = kwargs.get("operation") | |||
| @@ -13,10 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class ReshapeMapper(ONNXToMindSporeMapper): | |||
| @@ -34,40 +32,37 @@ class ReshapeMapper(ONNXToMindSporeMapper): | |||
| def _convert_trained_weights(**kwargs): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| if kwargs.get("weights", None): | |||
| return ReshapeMapper._convert_settings_tf(**kwargs) | |||
| return ReshapeMapper._convert_settings_pytorch(**kwargs) | |||
| @staticmethod | |||
| def _convert_settings_pytorch(**kwargs): | |||
| params = kwargs.get("params") | |||
| shape = params.get("output_shape") | |||
| return Setting(op_extra_input={"input_shape": tuple(shape)}) | |||
| @staticmethod | |||
| def _convert_settings_tf(**kwargs): | |||
| weights = kwargs.get("weights") | |||
| if len(weights) > 1: | |||
| raise ValueError("For reshape, `weights` length should equal to 1.") | |||
| shape = [-1] | |||
| shape += list(weights.values())[0][1:].tolist() | |||
| return Setting(op_extra_input={"shape": tuple(shape)}) | |||
| @staticmethod | |||
| def _generate_snippet_template(**kwargs): | |||
| template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( | |||
| **kwargs) | |||
| weights = kwargs.get("weights") | |||
| if len(weights) > 1: | |||
| raise ValueError("For reshape, `weights` length should equal to 1.") | |||
| shape = [-1] | |||
| shape += list(weights.values())[0][1:].tolist() | |||
| output_shape = kwargs.get("raw_params").get("output_shape") | |||
| variable_slot = "var_0" | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(shape)})" | |||
| template = reset_init_or_construct(template, variable_slot, [construct_template], | |||
| TemplateKeywords.CONSTRUCT.value) | |||
| op = kwargs.get("operation") | |||
| init_template = f"self.{{{variable_slot}}} = {op}()" | |||
| target_shape = f"self.{{{variable_slot}}}_shape = tuple({{shape}})" | |||
| args = {"shape": output_shape} | |||
| construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ | |||
| f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ | |||
| f"self.{{{variable_slot}}}_shape)" | |||
| template = { | |||
| variable_slot: { | |||
| TemplateKeywords.INIT.value: [init_template, target_shape], | |||
| TemplateKeywords.CONSTRUCT.value: [construct_template] | |||
| } | |||
| } | |||
| exchange_msg = { | |||
| variable_slot: { | |||
| ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, | |||
| ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, | |||
| ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: | |||
| ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, | |||
| ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], | |||
| ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, | |||
| ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, | |||
| ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} | |||
| } | |||
| } | |||
| outputs_list = [f"opt_{{{variable_slot}}}"] | |||
| outputs_mapping = ((0, 0),) | |||
| return template, exchange_msg, outputs_list, outputs_mapping | |||
| @@ -13,8 +13,8 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper | |||
| from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting | |||
| class SplitMapper(ONNXToMindSporeMapper): | |||
| @@ -37,5 +37,34 @@ class SplitMapper(ONNXToMindSporeMapper): | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_settings(**kwargs): | |||
| return Setting() | |||
def _generate_snippet_template(**kwargs):
    """Build the init/construct code template and exchange message for the Split op."""
    op = kwargs.get("operation")
    args = kwargs.get("converted_params", dict())
    weights = kwargs.get("weights")
    if not op:
        raise ValueError("Can not get MindSpore operation name.")
    variable_slot = "var_0"
    # Render one "name={name}" placeholder pair per converted parameter.
    init_args = ", ".join("%s={%s}" % (param, param) for param in args)
    init_template = f"self.{{{variable_slot}}} = {op}({init_args})"
    construct_template = (f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}"
                          f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})")
    template = {
        variable_slot: {
            TemplateKeywords.INIT.value: [init_template],
            TemplateKeywords.CONSTRUCT.value: [construct_template]
        }
    }
    exchange_msg = {
        variable_slot: {
            ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
            ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
            ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
                ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value,
            ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
            ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
            ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
            ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
        }
    }
    outputs_list = [f"opt_{{{variable_slot}}}"]
    outputs_mapping = ((0, 0),)
    return template, exchange_msg, outputs_list, outputs_mapping
| @@ -24,10 +24,11 @@ import numpy as np | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from ..common.utils import fetch_output_from_onnx_model | |||
| from ..common.global_context import GlobalContext | |||
| from .optimizer import OnnxSimplify | |||
| from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \ | |||
| ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL | |||
| from ...common.exceptions import GraphInitError, ModelNotSupportError, ModelLoadingError | |||
| from ...common.exceptions import GraphInitError, ModelLoadingError | |||
| def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): | |||
| @@ -258,9 +259,11 @@ class OnnxDataLoader: | |||
| def __init__(self, onnx_model, graph_input_shape: Union[tuple, list], | |||
| input_nodes: list, output_nodes: list, infer_shape=True): | |||
| self.model = onnx_model | |||
| self.graph = onnx_model.graph | |||
| self.nodes = onnx_model.graph.node | |||
| onnx_sim = OnnxSimplify() | |||
| onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, graph_input_shape) | |||
| self.model = onnx_model_sim | |||
| self.graph = onnx_model_sim.graph | |||
| self.nodes = onnx_model_sim.graph.node | |||
| self.graph_input_shape = graph_input_shape | |||
| self.input_nodes = input_nodes if isinstance(input_nodes, list) else [input_nodes] | |||
| self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes] | |||
| @@ -388,9 +391,8 @@ class OnnxDataLoader: | |||
| n = OnnxNode(node) | |||
| self._nodes_dict[n.name] = n | |||
| nodes_topo_idx.append((idx, n.name)) | |||
| if len(node.output) > 1: | |||
| raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.") | |||
| self.output_name_to_node_name[node.output[0]] = node.name | |||
| for out in node.output: | |||
| self.output_name_to_node_name[out] = node.name | |||
| for ipt_nd in node.input: | |||
| if ipt_nd not in self.output_name_to_node_name: | |||
| @@ -17,10 +17,8 @@ import os | |||
| from importlib import import_module | |||
| from mindinsight.mindconverter.common.log import logger as log | |||
| from .base import GraphParser | |||
| from .optimizer import OnnxSimplify | |||
| from ...common.exceptions import ModelNotSupportError | |||
| from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser | |||
| from mindinsight.mindconverter.common.exceptions import ModelNotSupportError | |||
| class PyTorchGraphParser(GraphParser): | |||
| """Define pytorch graph parser.""" | |||
| @@ -106,7 +104,4 @@ class PyTorchGraphParser(GraphParser): | |||
| onnx = import_module('onnx') | |||
| onnx_model = onnx.load_model_from_string(proto) | |||
| onnx_simplify = OnnxSimplify() | |||
| onnx_model_sim = onnx_simplify.run_onnx_simplify(onnx_model, sample_shape) | |||
| return onnx_model_sim | |||
| return onnx_model | |||