Browse Source

Add avgpool mapper.

tags/v1.0.0
liuchongming74 liuchongming74 5 years ago
parent
commit
3e9a289206
5 changed files with 61 additions and 30 deletions
  1. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  2. +11
    -9
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py
  3. +1
    -0
      mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json
  4. +33
    -18
      mindinsight/mindconverter/graph_based_converter/report_generator.py
  5. +15
    -2
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -99,7 +99,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
converted_params = params_converter(params, weights)
converted_weights = weights_converter(weights) if weights else dict()
converted_params.update(converted_weights)
except (AttributeError,) as _:
except (AttributeError, KeyError, ValueError, TypeError) as _:
print(f"Converting {op_name} failed.")
return None, dict()



+ 11
- 9
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py View File

@@ -30,15 +30,17 @@ class PoolMapper(ONNXToMindSporeMapper):

@staticmethod
def _convert_params(params, weights):
if sum(params['pads']) == 0:
pad_mode = '\"valid\"'
else:
pad_mode = '\"same\"'
return {
'kernel_size': tuple(params['kernel_shape']),
'stride': tuple(params['strides']),
'pad_mode': pad_mode
}
transformed_params = dict()
transformed_params["kernel_size"] = tuple(params['kernel_shape'])
transformed_params["stride"] = tuple(params['strides'])
if "pads" in params:
if sum(params['pads']) == 0:
pad_mode = '\"valid\"'
else:
pad_mode = '\"same\"'
transformed_params["pad_mode"] = pad_mode

return transformed_params

@staticmethod
def _convert_trained_weights(weights):


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/mapper/onnx_to_ms.json View File

@@ -4,6 +4,7 @@
"onnx::BatchNormalization": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.batch_norm_mapper.BatchNormMapper",
"onnx::Relu": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.relu_mapper.ReLUMapper",
"onnx::MaxPool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pool_mapper.PoolMapper",
"onnx::AveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.pool_mapper.PoolMapper",
"onnx::GlobalAveragePool": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.global_pool_mapper.GlobalPoolMapper",
"onnx::Flatten": "mindinsight.mindconverter.graph_based_converter.mapper.impl.nn.flatten_mapper.FlattenMapper",
"onnx::Add": "mindinsight.mindconverter.graph_based_converter.mapper.impl.ops.add_mapper.AddMapper"

+ 33
- 18
mindinsight/mindconverter/graph_based_converter/report_generator.py View File

@@ -19,7 +19,8 @@ from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, S


class ReportGenerator(metaclass=abc.ABCMeta):
"""Generate report."""
"""Generate report of scripts transformation."""

def __init__(self):
self._title = self._gen_title()
self._extra = self._gen_extra()
@@ -27,27 +28,40 @@ class ReportGenerator(metaclass=abc.ABCMeta):

@staticmethod
def _gen_title():
"""Generate title."""
"""
Generate title of scripts transformation.

Returns:
str, title of scripts transformation report.
"""
title_info = ''
return title_info

@staticmethod
def _gen_extra():
"""Generate extra information."""
"""
Generate extra information.

Returns:
str, body of scripts transformation report.
"""
extra_info = {'start': '[Start Convert]',
'end': '[Convert Over]'}
return extra_info

@property
def title(self):
"""Title property."""
return self._title

@property
def extra(self):
"""Extra property."""
return self._extra

@staticmethod
def _gen_converted_operator_content(converted_location: list, converted_operator_name):
def _gen_converted_operator_content(converted_location: list,
converted_operator_name):
"""
Generate converted operator content.

@@ -61,11 +75,13 @@ class ReportGenerator(metaclass=abc.ABCMeta):
unconverted_operator_name = ''
content = \
f"line {':'.join(converted_location)}: " \
f"[Convert]'{unconverted_operator_name}' is converted to {converted_operator_name}."
f"[Convert]'{unconverted_operator_name}' is converted to " \
f"{converted_operator_name}."
return content

@staticmethod
def _gen_unconverted_operator_content(unconverted_location: list, unconverted_operator_name):
def _gen_unconverted_operator_content(unconverted_location: list,
unconverted_operator_name):
"""
Generate unconverted operator content.

@@ -76,9 +92,8 @@ class ReportGenerator(metaclass=abc.ABCMeta):
Returns:
String, report content of unconverted operator.
"""
content = \
f"line {':'.join(unconverted_location)}: " \
f"[UnConvert] '{unconverted_operator_name}' didn't convert."
content = f"line {':'.join(unconverted_location)}: " \
f"[UnConvert] '{unconverted_operator_name}' didn't convert."
return content

def gen_report(self, code: str):
@@ -100,18 +115,18 @@ class ReportGenerator(metaclass=abc.ABCMeta):
code_line = code_lines[num_line]
if 'onnx.' in code_line:
num_unconverted_line += 1
unconverted_operator = \
SEPARATOR_IN_ONNX_OP.join(('onnx', re.findall(r".*onnx.(.*)[(]", code_line)[0]))
info_unconverted_line = \
self._gen_unconverted_operator_content(
[f"{num_line + 1}", f"{code_line.index('onnx.') + 1}"],
unconverted_operator
)
self._content = f"{NEW_LINE}".join((self._content, info_unconverted_line))
unconverted_operator = SEPARATOR_IN_ONNX_OP.join(
('onnx', re.findall(r".*onnx.(.*)[(]", code_line)[0]))
info_unconverted_line = self._gen_unconverted_operator_content(
[f"{num_line + 1}", f"{code_line.index('onnx.') + 1}"],
unconverted_operator
)
self._content = f"{NEW_LINE}".join((self._content,
info_unconverted_line))
self._content = f"{NEW_LINE}".join((self._content, self._extra['end']))

converted_rate = (num_all_lines - num_unconverted_line) / num_all_lines
info_converted_rate = f"Converted Rate: {converted_rate * 100:.2f}%."
info_converted_rate = f"Converted Rate: {converted_rate * 100:.2f}%.{NEW_LINE}"
self._content = f"{NEW_LINE}".join((self._content, info_converted_rate))

return self._content

+ 15
- 2
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -111,6 +111,18 @@ class PyTorchGraph(Graph):
self._check_input_shape(input_shape)

def _extract_shape(shape):
"""
Extract shape from string-type shape.

Args:
shape (str): Shape value in string-type.

Returns:
list, shape.
"""
pattern = re.compile(r"\d+:\d*")
if not pattern.findall(shape):
return []
return [int(x.split(":")[0].replace("!", "")) for x in shape.split(',')]

feed_forward_ipt_shape = (1, *input_shape)
@@ -132,7 +144,7 @@ class PyTorchGraph(Graph):
output_shape_str = output_shape_str_list[1]
output_shape = _extract_shape(output_shape_str)
weight_scope = '.'.join(
re.findall(r'\[([\w\d.]+)\]', node.scopeName())
re.findall(r'\[([\w\d.]+)]', node.scopeName())
)
node_weight = {}
for scope, weight in self._params_dict.items():
@@ -195,7 +207,8 @@ class PyTorchGraph(Graph):
node_input = self._input_shape.get(node_name)

if not node_input:
raise ValueError(f"Cannot find {node_name}'s input shape.")
raise ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output


Loading…
Cancel
Save