Browse Source

Support onnx split op multiple outputs

tags/v1.2.0-rc1
liangtianshu 5 years ago
parent
commit
b7d831ccae
7 changed files with 163 additions and 62 deletions
  1. +14
    -0
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  2. +77
    -2
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  3. +0
    -9
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  4. +28
    -33
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py
  5. +32
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py
  6. +9
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  7. +3
    -8
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py

+ 14
- 0
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -219,3 +219,17 @@ def reset_init_or_construct(template, variable_slot, new_data, scope):
template[variable_slot][scope].clear()
template[variable_slot][scope] += new_data
return template

def replace_string_in_list(str_list: list, original_str: str, target_str: str):
"""
Replace a string in a list by provided string.

Args:
str_list (list): A list contains the string to be replaced.
original_str (str): The string to be replaced.
target_str (str): The replacement of string.

Returns,
list, the original list with replaced string.
"""
return [s.replace(original_str, target_str) for s in str_list]

+ 77
- 2
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -27,9 +27,10 @@ from mindinsight.mindconverter.graph_based_converter.common.global_context impor
from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseOutput, ModuleOutputManager
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, \
get_imported_module
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \
FIRST_LEVEL_INDENT, get_imported_module
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list


class CodeStruct:
@@ -373,6 +374,10 @@ class Generator:
self.module_structs.get('[]').allocate_construct_header_x()
self.module_structs.get('[]').collect_returns()

for nd_struct in self.node_structs.values():
if nd_struct.fragment.metadata.get("operation") == "Split":
self._split_op_procs(nd_struct)

def _update_all_modules_args_translator(self):
"""Update all modules' args translators."""
done_submodule = set()
@@ -633,3 +638,73 @@ class Generator:
if user in root_module.external_successor_nodes_names:
return True
return False

def _split_op_procs(self, split_struct: NodeStruct):
"""
Support for Split operation multiple outputs.

Args:
split_struct (NodeStruct): The NodeStruct of the Split op.
"""
for successor in split_struct.successor_nodes_structs:
# 1. target user is internal
if split_struct.check_target_node_internal(successor.identifier):
idx = self._get_correct_input_idx_from_split(split_struct, successor)
if idx is None:
raise ValueError("The Split OP should not has empty output.")
correct_input = split_struct.fragment.fragment.get_outputs_by_idx(0, idx)
to_be_replaced = None
for inp in successor.matched_inputs:
if "split" in inp:
to_be_replaced = inp
break
if to_be_replaced is not None:
successor.matched_inputs = replace_string_in_list(successor.matched_inputs,
to_be_replaced,
correct_input)
# 2. target user is external
else:
public_parent = self._get_public_parent_module(split_struct, successor)
to_be_modified_md = self._get_submodule_has_out_user_under_public_parent(public_parent, successor)
idx = self._get_correct_input_idx_from_split(split_struct, successor)
if idx is None:
raise ValueError("The Split OP should not has empty output.")
if to_be_modified_md is None:
raise ValueError("Unable to locate the submodule to be modified for Split output matching.")
correct_input = split_struct.fragment.fragment.get_outputs_by_idx(0, idx)
to_be_replaced = None
for inp in to_be_modified_md.matched_inputs:
if "split" in inp:
to_be_replaced = inp
break
if to_be_replaced is not None:
to_be_modified_md.matched_inputs = replace_string_in_list(to_be_modified_md.matched_inputs,
to_be_replaced,
correct_input)

def _get_correct_input_idx_from_split(self, split_struct: NodeStruct, split_out_user: NodeStruct):
"""Return the index of the split output the user used."""
split_struct_out_edges = split_struct.fragment.metadata.get("outputs")
for idx, out in enumerate(split_struct_out_edges):
if out in split_out_user.fragment.metadata.get("inputs"):
return idx
return None

def _get_public_parent_module(self, node_a: NodeStruct, node_b: NodeStruct):
"""Return the public parent module of both Node A and Node B."""
find = False
b_onnx_name = node_b.onnx_name
tmp = node_a
while not find:
parent_struct = tmp.parent_module_struct
if b_onnx_name in parent_struct.onnx_names:
find = True
tmp = parent_struct
return tmp

def _get_submodule_has_out_user_under_public_parent(self, public_module: ModuleStruct, node_out_user: NodeStruct):
"""Return the ModuleStruct which under the public module and contains the NodeStruct which provided."""
for module_struct in public_module.module_structs:
if node_out_user.onnx_name in module_struct.onnx_names:
return module_struct
return None

+ 0
- 9
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -58,11 +58,6 @@ class Mapper(metaclass=abc.ABCMeta):
def _convert_trained_weights(**kwargs):
"""Convert third party operation's weights into MindSpore operation."""

@staticmethod
@abc.abstractmethod
def _convert_settings(**kwargs):
"""Convert third party operation's params into MindSpore OP operator."""

@classmethod
@abc.abstractmethod
def convert(cls, op_name: str, params: Dict, weights: Dict = None):
@@ -148,10 +143,6 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
def _convert_trained_weights(**kwargs):
raise NotImplementedError

@staticmethod
def _convert_settings(**kwargs):
raise NotImplementedError

@staticmethod
def _generate_snippet_template(**kwargs):
op = kwargs.get("operation")


+ 28
- 33
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reshape_mapper.py View File

@@ -13,10 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting


class ReshapeMapper(ONNXToMindSporeMapper):
@@ -34,40 +32,37 @@ class ReshapeMapper(ONNXToMindSporeMapper):
def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
if kwargs.get("weights", None):
return ReshapeMapper._convert_settings_tf(**kwargs)
return ReshapeMapper._convert_settings_pytorch(**kwargs)

@staticmethod
def _convert_settings_pytorch(**kwargs):
params = kwargs.get("params")
shape = params.get("output_shape")
return Setting(op_extra_input={"input_shape": tuple(shape)})

@staticmethod
def _convert_settings_tf(**kwargs):
weights = kwargs.get("weights")
if len(weights) > 1:
raise ValueError("For reshape, `weights` length should equal to 1.")
shape = [-1]
shape += list(weights.values())[0][1:].tolist()
return Setting(op_extra_input={"shape": tuple(shape)})

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
weights = kwargs.get("weights")
if len(weights) > 1:
raise ValueError("For reshape, `weights` length should equal to 1.")
shape = [-1]
shape += list(weights.values())[0][1:].tolist()
output_shape = kwargs.get("raw_params").get("output_shape")
variable_slot = "var_0"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {tuple(shape)})"
template = reset_init_or_construct(template, variable_slot, [construct_template],
TemplateKeywords.CONSTRUCT.value)
op = kwargs.get("operation")
init_template = f"self.{{{variable_slot}}} = {op}()"
target_shape = f"self.{{{variable_slot}}}_shape = tuple({{shape}})"
args = {"shape": output_shape}

construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \
f"self.{{{variable_slot}}}_shape)"
template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template, target_shape],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}
exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+ 32
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/split_mapper.py View File

@@ -13,8 +13,8 @@
# limitations under the License.
# ==============================================================================
"""Mapper module."""
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting


class SplitMapper(ONNXToMindSporeMapper):
@@ -37,5 +37,34 @@ class SplitMapper(ONNXToMindSporeMapper):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()
def _generate_snippet_template(**kwargs):
op = kwargs.get("operation")
args = kwargs.get("converted_params", dict())
weights = kwargs.get("weights")
if not op:
raise ValueError("Can not get MindSpore operation name.")
variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})"
template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template],
TemplateKeywords.CONSTRUCT.value: [construct_template]
}
}
exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.ARR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
}
}
outputs_list = [f"opt_{{{variable_slot}}}"]
outputs_mapping = ((0, 0),)
return template, exchange_msg, outputs_list, outputs_mapping

+ 9
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -24,10 +24,11 @@ import numpy as np
from mindinsight.mindconverter.common.log import logger as log
from ..common.utils import fetch_output_from_onnx_model
from ..common.global_context import GlobalContext
from .optimizer import OnnxSimplify

from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitError, ModelNotSupportError, ModelLoadingError
from ...common.exceptions import GraphInitError, ModelLoadingError


def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -258,9 +259,11 @@ class OnnxDataLoader:

def __init__(self, onnx_model, graph_input_shape: Union[tuple, list],
input_nodes: list, output_nodes: list, infer_shape=True):
self.model = onnx_model
self.graph = onnx_model.graph
self.nodes = onnx_model.graph.node
onnx_sim = OnnxSimplify()
onnx_model_sim = onnx_sim.run_onnx_simplify(onnx_model, graph_input_shape)
self.model = onnx_model_sim
self.graph = onnx_model_sim.graph
self.nodes = onnx_model_sim.graph.node
self.graph_input_shape = graph_input_shape
self.input_nodes = input_nodes if isinstance(input_nodes, list) else [input_nodes]
self.output_nodes = output_nodes if isinstance(output_nodes, list) else [output_nodes]
@@ -388,9 +391,8 @@ class OnnxDataLoader:
n = OnnxNode(node)
self._nodes_dict[n.name] = n
nodes_topo_idx.append((idx, n.name))
if len(node.output) > 1:
raise ModelNotSupportError(msg=f"{node.name} has multi-outputs which is not supported now.")
self.output_name_to_node_name[node.output[0]] = node.name
for out in node.output:
self.output_name_to_node_name[out] = node.name

for ipt_nd in node.input:
if ipt_nd not in self.output_name_to_node_name:


+ 3
- 8
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -17,10 +17,8 @@ import os
from importlib import import_module

from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from .optimizer import OnnxSimplify
from ...common.exceptions import ModelNotSupportError

from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser
from mindinsight.mindconverter.common.exceptions import ModelNotSupportError

class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser."""
@@ -106,7 +104,4 @@ class PyTorchGraphParser(GraphParser):
onnx = import_module('onnx')
onnx_model = onnx.load_model_from_string(proto)

onnx_simplify = OnnxSimplify()
onnx_model_sim = onnx_simplify.run_onnx_simplify(onnx_model, sample_shape)

return onnx_model_sim
return onnx_model

Loading…
Cancel
Save