| @@ -489,9 +489,9 @@ class HierarchicalTree(Tree): | |||
| """ | |||
| return s.split(SEPARATOR_IN_SCOPE)[-1].lower().split(SEPARATOR_BTW_NAME_AND_ID)[0] | |||
| def _get_previous_opt_var(self, cur_nd, pre_nd): | |||
| def _find_all_previous_opt_var_(self, cur_nd, pre_nd): | |||
| """ | |||
| Get needed input variable names. | |||
| Find all input varian names. | |||
| Args: | |||
| cur_nd (Node): Current node. | |||
| @@ -501,28 +501,41 @@ class HierarchicalTree(Tree): | |||
| str, needed var names. | |||
| """ | |||
| ipt_lst = [] | |||
| if cur_nd.data.node_type == NodeType.OPERATION.value: | |||
| for e in cur_nd.data.precursor_nodes: | |||
| p_nd = self.get_node(e) | |||
| if e not in pre_nd.successors(self.tree_identifier): | |||
| while True: | |||
| if p_nd.identifier in pre_nd.successors(self.tree_identifier): | |||
| ipt_lst.append(p_nd.data.opt_var_name) | |||
| break | |||
| pre_nd_name = p_nd.predecessor(self.tree_identifier) | |||
| if not pre_nd_name: | |||
| ipt_lst.append("x") | |||
| break | |||
| p_nd = self.get_node(pre_nd_name) | |||
| continue | |||
| for e in cur_nd.data.precursor_nodes: | |||
| p_nd = self.get_node(e) | |||
| if e not in pre_nd.successors(self.tree_identifier): | |||
| while True: | |||
| if p_nd.identifier in pre_nd.successors(self.tree_identifier): | |||
| ipt_lst.append(p_nd.data.opt_var_name) | |||
| break | |||
| pre_nd_name = p_nd.predecessor(self.tree_identifier) | |||
| if not pre_nd_name: | |||
| ipt_lst.append("x") | |||
| break | |||
| p_nd = self.get_node(pre_nd_name) | |||
| continue | |||
| ipt_lst.append(p_nd.data.opt_var_name) | |||
| else: | |||
| idx = pre_nd.successors(self.tree_identifier).index(cur_nd.identifier) - 1 | |||
| p_nd = self.get_node(pre_nd.successors(self.tree_identifier)[idx]) | |||
| ipt_lst.append(p_nd.data.opt_var_name) | |||
| return ipt_lst | |||
| def _get_previous_opt_var(self, cur_nd, pre_nd): | |||
| """ | |||
| Get needed input variable names. | |||
| return ", ".join(ipt_lst) | |||
| Args: | |||
| cur_nd (Node): Current node. | |||
| pre_nd (Node): Precursor node. | |||
| Returns: | |||
| str, needed var names. | |||
| """ | |||
| if cur_nd.data.node_type != NodeType.OPERATION.value: | |||
| while True: | |||
| p_nd = cur_nd.successors(self.tree_identifier) | |||
| if not p_nd: | |||
| break | |||
| cur_nd = self.get_node(p_nd[0]) | |||
| return ", ".join(self._find_all_previous_opt_var_(cur_nd, pre_nd)) | |||
| def hash_key(self, node): | |||
| """ | |||
| @@ -41,12 +41,12 @@ class Mapper(metaclass=abc.ABCMeta): | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _operation_name_in_ms(): | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| """Corresponding operation name in mindspore.""" | |||
| @staticmethod | |||
| @abc.abstractmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| """Convert third party operation's param into MindSpore operation.""" | |||
| @staticmethod | |||
| @@ -95,8 +95,8 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| return None, dict() | |||
| try: | |||
| converter_name = op_name_converter() | |||
| converted_params = params_converter(params) | |||
| converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) | |||
| converted_params = params_converter(params, weights) | |||
| converted_weights = weights_converter(weights) if weights else dict() | |||
| converted_params.update(converted_weights) | |||
| except (AttributeError,) as _: | |||
| @@ -106,11 +106,11 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC): | |||
| return converter_name, converted_params | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| @@ -20,13 +20,14 @@ class BatchNormMapper(ONNXToMindSporeMapper): | |||
| """BatchNorm mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| return "nn.BatchNorm2d" | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| dim = len(kwargs['params']['output_shape']) - 2 | |||
| return f"nn.BatchNorm{dim}d" | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| return { | |||
| 'num_features': params['input_shape'][1], | |||
| 'num_features': params['output_shape'][1], | |||
| 'eps': params['epsilon'], | |||
| 'momentum': params['momentum'] | |||
| } | |||
| @@ -1,41 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| from ...base import ONNXToMindSporeMapper | |||
| class Conv2dMapper(ONNXToMindSporeMapper): | |||
| """Conv2d mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| return "nn.Conv2d" | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| return { | |||
| 'in_channels': params['input_shape'][1], | |||
| 'out_channels': params['output_shape'][1], | |||
| 'kernel_size': params['kernel_shape'], | |||
| 'stride': params['strides'][0], | |||
| 'pad': params['pads'][0], | |||
| 'dilation': params['dilations'][0], | |||
| 'group': params['group']} | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| return dict() | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Mapper module.""" | |||
| import numpy as np | |||
| from ...base import ONNXToMindSporeMapper | |||
| class ConvMapper(ONNXToMindSporeMapper): | |||
| """Conv2d mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| weight = kwargs['weights']['weight'].numpy() | |||
| dim = weight.ndim - 2 | |||
| return f"nn.Conv{dim}d" | |||
| @staticmethod | |||
| def _convert_params(params, weights): | |||
| weight = weights['weight'].numpy() | |||
| weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) | |||
| if isinstance(params['dilations'], list): | |||
| dilation = tuple(params['dilations']) | |||
| else: | |||
| dilation = params['dilations'] | |||
| if isinstance(params['strides'], list): | |||
| stride = tuple(params['strides']) | |||
| else: | |||
| stride = params['strides'] | |||
| kernel_shape = list(weight.shape) | |||
| in_channels = kernel_shape[-2] | |||
| out_channels = kernel_shape[-1] | |||
| kernel_size = kernel_shape[:-2] | |||
| if len(kernel_size) == 1: | |||
| kernel_size = kernel_size[0] | |||
| else: | |||
| kernel_size = tuple(kernel_size) | |||
| pad_mode, padding = ConvMapper._convert_padding(params) | |||
| return { | |||
| 'in_channels': in_channels, | |||
| 'out_channels': out_channels, | |||
| 'kernel_size': kernel_size, | |||
| 'stride': stride, | |||
| 'padding': padding, | |||
| 'pad_mode': pad_mode, | |||
| 'dilation': dilation, | |||
| 'group': params['group']} | |||
| @staticmethod | |||
| def _convert_trained_weights(weights): | |||
| if weights: | |||
| pass | |||
| return dict() | |||
| @staticmethod | |||
| def _convert_padding(params): | |||
| if sum(params['pads']) == 0: | |||
| return '\"valid\"', 0 | |||
| return '\"pad\"', tuple(params['pads']) | |||
| @@ -20,14 +20,18 @@ class DenseMapper(ONNXToMindSporeMapper): | |||
| """Dense mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.Dense" | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| has_bias = bool('bias' in weights) | |||
| weight = weights['weight'].numpy().transpose() | |||
| in_channels, out_channels = weight.shape | |||
| return { | |||
| 'in_channels': params['input_shape'][1], | |||
| 'out_channels': params['output_shape'][1] | |||
| 'in_channels': in_channels, | |||
| 'out_channels': out_channels, | |||
| 'has_bias': has_bias | |||
| } | |||
| @staticmethod | |||
| @@ -20,11 +20,11 @@ class FlattenMapper(ONNXToMindSporeMapper): | |||
| """Flatten mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.Flatten" | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| if params: | |||
| pass | |||
| return dict() | |||
| @@ -16,18 +16,29 @@ | |||
| from ...base import ONNXToMindSporeMapper | |||
| class GlobalAvgPoolMapper(ONNXToMindSporeMapper): | |||
| class GlobalPoolMapper(ONNXToMindSporeMapper): | |||
| """AvgPool mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| return "nn.AvgPool2d" | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| if kwargs['op_name'] == 'onnx::GlobalAveragePool': | |||
| op_name = 'nn.AvgPool{}d' | |||
| else: | |||
| op_name = 'nn.MaxPool{}d' | |||
| dim = 1 if len(kwargs['params']['input_shape']) == 3\ | |||
| else 2 | |||
| return op_name.format(dim) | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| kernel_size_height = params['input_shape'][2] // params['output_shape'][2] | |||
| kernel_size_width = params['input_shape'][3] // params['output_shape'][3] | |||
| kernel_size = [kernel_size_height, kernel_size_width] | |||
| def _convert_params(params, weights): | |||
| dim = 1 if len(params['input_shape']) == 3\ | |||
| else 2 | |||
| if dim == 1: | |||
| kernel_size = params['input_shape'][-1] // params['output_shape'][-1] | |||
| else: | |||
| kernel_size_height = params['input_shape'][-2] // params['output_shape'][-2] | |||
| kernel_size_width = params['input_shape'][-1] // params['output_shape'][-1] | |||
| kernel_size = (kernel_size_height, kernel_size_width) | |||
| return { | |||
| 'kernel_size': kernel_size | |||
| } | |||
| @@ -16,18 +16,28 @@ | |||
| from ...base import ONNXToMindSporeMapper | |||
| class MaxPoolMapper(ONNXToMindSporeMapper): | |||
| class PoolMapper(ONNXToMindSporeMapper): | |||
| """MaxPool mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| return "nn.MaxPool2d" | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| if kwargs['op_name'] == 'onnx::AveragePool': | |||
| op_name = 'nn.AvgPool{}d' | |||
| else: | |||
| op_name = 'nn.MaxPool{}d' | |||
| dim = len(kwargs['params']['strides']) | |||
| return op_name.format(dim) | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| if sum(params['pads']) == 0: | |||
| pad_mode = '\"valid\"' | |||
| else: | |||
| pad_mode = '\"same\"' | |||
| return { | |||
| 'kernel_size': params['kernel_shape'], | |||
| 'stride': params['strides'] | |||
| 'kernel_size': tuple(params['kernel_shape']), | |||
| 'stride': tuple(params['strides']), | |||
| 'pad_mode': pad_mode | |||
| } | |||
| @staticmethod | |||
| @@ -20,11 +20,11 @@ class ReLUMapper(ONNXToMindSporeMapper): | |||
| """ReLU mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "nn.ReLU" | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| if params: | |||
| pass | |||
| return dict() | |||
| @@ -20,11 +20,11 @@ class AddMapper(ONNXToMindSporeMapper): | |||
| """Add mapper.""" | |||
| @staticmethod | |||
| def _operation_name_in_ms(): | |||
| def _operation_name_in_ms(*args, **kwargs): | |||
| return "P.TensorAdd" | |||
| @staticmethod | |||
| def _convert_params(params): | |||
| def _convert_params(params, weights): | |||
| if params: | |||
| pass | |||
| return dict() | |||
| @@ -1,10 +1,10 @@ | |||
| { | |||
| "onnx::Conv": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.conv2d_mapper.Conv2dMapper", | |||
| "onnx::Conv": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.conv_mapper.ConvMapper", | |||
| "onnx::Gemm": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.dense_mapper.DenseMapper", | |||
| "onnx::BatchNormalization": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.batch_norm_mapper.BatchNormMapper", | |||
| "onnx::Relu": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper", | |||
| "onnx::MaxPool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.max_pool_mapper.MaxPoolMapper", | |||
| "onnx::GlobalAveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.avg_pool_mapper.GlobalAvgPoolMapper", | |||
| "onnx::MaxPool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pool_mapper.PoolMapper", | |||
| "onnx::GlobalAveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.global_pool_mapper.GlobalPoolMapper", | |||
| "onnx::Flatten": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.flatten_mapper.FlattenMapper", | |||
| "onnx::Add": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.add_mapper.AddMapper" | |||
| } | |||
| @@ -288,6 +288,8 @@ class GraphNode(abc.ABC): | |||
| self._ipt_shape = None | |||
| # Output shape of current op. | |||
| self._opt_shape = None | |||
| # Weight of current op. | |||
| self._weight = None | |||
| @property | |||
| def opt_var_name(self): | |||
| @@ -536,7 +538,8 @@ class GraphNode(abc.ABC): | |||
| "output_shape": self.output_shape}) | |||
| op_name_in_mindspore, ms_params = mapper.convert(op_name=self.op_name, | |||
| params=params) | |||
| params=params, | |||
| weights=self._weight) | |||
| if op_name_in_mindspore: | |||
| self._op_in_ms = op_name_in_mindspore | |||
| self._params_in_ms = ms_params | |||
| @@ -13,7 +13,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Define PyTorch graph.""" | |||
| import platform | |||
| import warnings | |||
| import re | |||
| from typing import Dict, NoReturn | |||
| @@ -112,9 +111,7 @@ class PyTorchGraph(Graph): | |||
| self._check_input_shape(input_shape) | |||
| def _extract_shape(shape): | |||
| if platform.system() == "Darwin": | |||
| return [int(x.split(":")[0]) for x in shape.split(',')] | |||
| return [int(x.replace("!", "")) for x in shape.split(',')] | |||
| return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')] | |||
| feed_forward_ipt_shape = (1, *input_shape) | |||
| batched_sample = create_autograd_variable(torch.rand(*feed_forward_ipt_shape)) | |||
| @@ -134,9 +131,16 @@ class PyTorchGraph(Graph): | |||
| output_shape_str_list = re.findall(r'[^()!]+', str(node)) | |||
| output_shape_str = output_shape_str_list[1] | |||
| output_shape = _extract_shape(output_shape_str) | |||
| weight_scope = '.'.join( | |||
| re.findall(r'\[([\w\d.]+)\]', node.scopeName()) | |||
| ) | |||
| node_weight = {} | |||
| for scope, weight in self._params_dict.items(): | |||
| split_scope = scope.split('.') | |||
| if '.'.join(split_scope[:-1]) == weight_scope: | |||
| node_weight[split_scope[-1]] = weight | |||
| self._shape_dict[node_name] = output_shape | |||
| self._nodes_collection[node_name] = PyTorchGraphNode(node) | |||
| self._nodes_collection[node_name] = PyTorchGraphNode(node, node_weight) | |||
| self._nodes_record[node_name] = node_name | |||
| for node_input in list(node.inputs()): | |||
| @@ -32,7 +32,7 @@ class PyTorchGraphNode(GraphNode): | |||
| _type_frozen = False | |||
| _module_name_frozen = False | |||
| def __init__(self, node=None): | |||
| def __init__(self, node=None, weight=None): | |||
| super(PyTorchGraphNode, self).__init__(node=node) | |||
| self._op_params = self._get_raw_params(node) | |||
| self._op_name = node.kind() if node else None | |||
| @@ -40,6 +40,7 @@ class PyTorchGraphNode(GraphNode): | |||
| self._opt_var_name = None | |||
| self._variable_name = self._extract_var_name(self._scope_name) | |||
| self._module_name = None | |||
| self._weight = weight | |||
| def clear_args_of_declaration(self): | |||
| """ | |||