From 803b9689f1450eade368b3990234b43a5ab6bb9b Mon Sep 17 00:00:00 2001 From: liangtianshu Date: Mon, 8 Mar 2021 21:07:28 +0800 Subject: [PATCH] add missing mapper and minor logic optimization --- .../generator/node_struct.py | 5 +- .../generator/shared_weights.py | 5 +- .../mapper/impl/nn/one_hot_mapper.py | 50 +++++++++++++++++++ .../mapper/impl/ops/neg_mapper.py | 32 ++++++++++++ .../mapper/impl/ops/reciprocal_mapper.py | 32 ++++++++++++ .../mapper/impl/ops/reduce_mean_mapper.py | 35 ++++++++++--- .../mapper/impl/ops/rsqrt_mapper.py | 32 ++++++++++++ .../mapper/onnx_to_ms.json | 4 ++ 8 files changed, 185 insertions(+), 10 deletions(-) create mode 100644 mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/one_hot_mapper.py create mode 100644 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/neg_mapper.py create mode 100644 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reciprocal_mapper.py create mode 100644 mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/rsqrt_mapper.py diff --git a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py index 6ec3b3bc..db57a4ca 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/node_struct.py @@ -349,9 +349,8 @@ class NodeStruct: self.fragment.default_var["parameters"][trainable_param_postfix] = declare_statement continue # not a shared weight, skip the rest - if onnx_name in self._global_context.repeated_weights_declaration.keys(): - continue # already declared, skip - self._global_context.repeated_weights_declaration[onnx_name] = declare_statement + if onnx_name not in self._global_context.repeated_weights_declaration.keys(): + self._global_context.repeated_weights_declaration[onnx_name] = declare_statement # set template to mapper parameter rewritten. shared_w_var_in_parent = self._get_shared_weight_var_names_from_parent(onnx_name=onnx_name) diff --git a/mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py b/mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py index 2913124d..b887f1af 100644 --- a/mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py +++ b/mindinsight/mindconverter/graph_based_converter/generator/shared_weights.py @@ -13,11 +13,11 @@ # limitations under the License. # ============================================================================== """Module rocessing for shared weights.""" +from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct - class SharedWeightHelper: """Helper function to process shared weights.""" @@ -61,6 +61,9 @@ class SharedWeightHelper: share_weight_name (str): The onnx name of the shared weights. pub_module_identifier (list): The identifier of the public module the shared weight in. """ + if not node.fragment.default_var.get(ExchangeMessageKeywords.VariableScope.value.PARAMETERS_DECLARED.value): + # No weight shared operator, skip + return parent_module = node.parent_module_struct exit_flag = False while True: diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/one_hot_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/one_hot_mapper.py new file mode 100644 index 00000000..fb3fa365 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/one_hot_mapper.py @@ -0,0 +1,50 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +import numpy as np + +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class OneHotMapper(ONNXToMindSporeMapper): + """OneHot mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "nn.OneHot" + + @staticmethod + def _convert_params(**kwargs): + params = kwargs.get('params') + converted_params = {} + if params.get('axis'): + converted_params['axis'] = params.get('axis') + if kwargs.get('weights'): + weights = kwargs.get('weights') + depth = weights[0] + val = weights[1] + if depth and isinstance(depth.value, np.ndarray): + ms_depth = depth.value[0] + converted_params['depth'] = ms_depth + if val and isinstance(val.value, np.ndarray): + ms_off_val = val.value[0] + ms_on_val = val.value[1] + converted_params['off_value'] = ms_off_val + converted_params['on_value'] = ms_on_val + return converted_params + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/neg_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/neg_mapper.py new file mode 100644 index 00000000..baa4253f --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/neg_mapper.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class NegMapper(ONNXToMindSporeMapper): + """Neg mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "P.Neg" + + @staticmethod + def _convert_params(**kwargs): + return dict() + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reciprocal_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reciprocal_mapper.py new file mode 100644 index 00000000..2e7111e1 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reciprocal_mapper.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class ReciprocalMapper(ONNXToMindSporeMapper): + """Reciprocal mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "P.Reciprocal" + + @staticmethod + def _convert_params(**kwargs): + return dict() + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py index ea736417..8dd26fab 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/reduce_mean_mapper.py @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== """Mapper module.""" -from mindinsight.mindconverter.graph_based_converter.common.utils import reset_init_or_construct from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper @@ -37,17 +36,41 @@ class ReduceMeanMapper(ONNXToMindSporeMapper): @staticmethod def _generate_snippet_template(**kwargs): - template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( - **kwargs) + op = kwargs.get("operation") + args = kwargs.get("converted_params") raw_params = kwargs.get("raw_params") if raw_params.get('axes'): axis = raw_params['axes'][0] if len(raw_params['axes']) == 1 else tuple(raw_params['axes']) else: axis = tuple() variable_slot = "var_0" + init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" + + args["axis"] = axis + init_tensor = f"self.{{{variable_slot}}}_axis = {{axis}}" + construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}" \ - f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, {axis})" - template = reset_init_or_construct(template, variable_slot, [construct_template], - TemplateKeywords.CONSTRUCT.value) + f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}}, " \ + f"self.{{{variable_slot}}}_axis)" + template = { + variable_slot: { + TemplateKeywords.INIT.value: [init_template, init_tensor], + TemplateKeywords.CONSTRUCT.value: [construct_template] + } + } + exchange_msg = { + variable_slot: { + ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op, + ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None, + ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value: + ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value, + ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], + ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, + ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: [], + ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {} + } + } + outputs_list = [f"opt_{{{variable_slot}}}"] + outputs_mapping = ((0, 0),) return template, exchange_msg, outputs_list, outputs_mapping diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/rsqrt_mapper.py b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/rsqrt_mapper.py new file mode 100644 index 00000000..f5005279 --- /dev/null +++ b/mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/rsqrt_mapper.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mapper module.""" +from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper + + +class RsqrtMapper(ONNXToMindSporeMapper): + """Rsqart mapper.""" + + @staticmethod + def _operation_name_in_ms(*args, **kwargs): + return "P.Rsqrt" + + @staticmethod + def _convert_params(**kwargs): + return dict() + + @staticmethod + def _convert_trained_weights(**kwargs): + return dict() diff --git a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json index bcec726b..589499ef 100644 --- a/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json +++ b/mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json @@ -15,6 +15,10 @@ "onnx::Transpose": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.transpose_mapper.TransposeMapper", "onnx::MatMul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.mat_mul_mapper.MatMulMapper", "onnx::Softmax": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.softmax_mapper.SoftmaxMapper", + "onnx::OneHot": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.one_hot_mapper.OneHotMapper", + "onnx::Neg": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.neg_mapper.NegMapper", + "onnx::Reciprocal": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reciprocal_mapper.ReciprocalMapper", + "onnx::Rsqrt": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.rsqrt_mapper.RsqrtMapper", "onnx::Reshape": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.reshape_mapper.ReshapeMapper", "onnx::Slice": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.slice_mapper.SliceMapper", "onnx::Mul": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.mul_mapper.MulMapper",