|
|
|
@@ -52,6 +52,17 @@ def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=None |
|
|
|
|
|
|
|
target = ",".join(constants.DEFAULT_TARGET) |
|
|
|
shape_override = None |
|
|
|
if not 'input' in model_inputs: |
|
|
|
error_msg = "The given input node is not an eligible input node." |
|
|
|
error = ValueError(error_msg) |
|
|
|
log.error(str(error)) |
|
|
|
raise error |
|
|
|
|
|
|
|
if 'input' in model_outputs: |
|
|
|
error_msg = "The given output node is an input node." |
|
|
|
error = ValueError(error_msg) |
|
|
|
log.error(str(error)) |
|
|
|
raise error |
|
|
|
|
|
|
|
if model_inputs: |
|
|
|
model_inputs, shape_override = utils.split_nodename_and_shape( |
|
|
|
@@ -254,8 +265,7 @@ class OnnxDataLoader: |
|
|
|
|
|
|
|
self.nodes_dict = OrderedDict() # {node_name: OnnxNode} NO INPUT NODE |
|
|
|
self.tensors_dict = {} # {tensor_name: OnnxTensor} |
|
|
|
# {node_name : (type, dim)} NO INPUT & OUTPUT NODE |
|
|
|
self.value_info_dict = {} |
|
|
|
self.value_info_dict = {} # Not contains input and output nodes |
|
|
|
|
|
|
|
self.tensor_name_set = set() # [str] |
|
|
|
self.node_name_set = set() # [str] |
|
|
|
@@ -395,7 +405,7 @@ class OnnxDataLoader: |
|
|
|
for node_name, node in self.nodes_dict.items(): |
|
|
|
# for each input of a node |
|
|
|
for input_name in node.input_name_list: |
|
|
|
# input_name = input_node:0, remove :0 here |
|
|
|
# remove :0 in the name to ensure consistency in hierarical tree. |
|
|
|
input_name = input_name.split(':')[0] |
|
|
|
if input_name in self.node_name_set: |
|
|
|
# input is a node |
|
|
|
|