Browse Source

rewrite code-like comments to avoid CI syntax warnings

Add error info. for input and output node name check
tags/v1.1.0
liangtianshu 5 years ago
parent
commit
5724ef7662
2 changed files with 14 additions and 4 deletions
  1. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  2. +13
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -100,7 +100,7 @@ class ConvMapper(ONNXToMindSporeMapper):
stride = params.get('strides')

kernel_size = params.get('kernel_shape')
# Onnx inchannel = ms inchannel / group
# Onnx in_channel equals ms inchannel divided by group
in_channels = weight.shape[1] * params.get('group', 1)
out_channels = weight.shape[0]
if len(kernel_size) == 1:


+ 13
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -52,6 +52,17 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None

target = ",".join(constants.DEFAULT_TARGET)
shape_override = None
if not 'input' in model_inputs:
error_msg = "The given input node is not an eligible input node."
error = ValueError(error_msg)
log.error(str(error))
raise error

if 'input' in model_outputs:
error_msg = "The given output node is an input node."
error = ValueError(error_msg)
log.error(str(error))
raise error

if model_inputs:
model_inputs, shape_override = utils.split_nodename_and_shape(
@@ -254,8 +265,7 @@ class OnnxDataLoader:

self.nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE
self.tensors_dict = {} # {tensor_name: OnnxTensor}
# {node_name : (type, dim)} NO INPUT & OUTPUT NODE
self.value_info_dict = {}
self.value_info_dict = {} # Not contains input and output nodes

self.tensor_name_set = set() # [str]
self.node_name_set = set() # [str]
@@ -395,7 +405,7 @@ class OnnxDataLoader:
for node_name, node in self.nodes_dict.items():
# for each input of a node
for input_name in node.input_name_list:
# input_name = input_node:0, remove :0 here
# remove :0 in the name to ensure consistency in hierarical tree.
input_name = input_name.split(':')[0]
if input_name in self.node_name_set:
# input is a node


Loading…
Cancel
Save