Browse Source

Add one exception catch & Optimize code

tags/v1.0.0
moran 5 years ago
parent
commit
be17cc4eae
2 changed files with 29 additions and 32 deletions
  1. +29
    -1
      mindinsight/mindconverter/cli.py
  2. +0
    -31
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py

+ 29
- 1
mindinsight/mindconverter/cli.py View File

@@ -109,6 +109,8 @@ class ProjectPathAction(argparse.Action):
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
if not os.path.isdir(outfile_dir):
parser.error(f'{option_string} [{outfile_dir}] should be a directory.')

@@ -138,6 +140,32 @@ class InFileAction(argparse.Action):
setattr(namespace, self.dest, outfile_dir)


class ModelFileAction(argparse.Action):
"""Model File action class definition."""

def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.

Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')

if not os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a file')

if not outfile_dir.endswith('.pth'):
parser.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.")

setattr(namespace, self.dest, outfile_dir)


class LogFileAction(argparse.Action):
"""Log file action class definition."""

@@ -208,7 +236,7 @@ def cli_entry():
parser.add_argument(
'--model_file',
type=str,
action=InFileAction,
action=ModelFileAction,
required=False,
help="""
Pytorch .pth model file path ot use graph


+ 0
- 31
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py View File

@@ -204,37 +204,6 @@ class PyTorchGraph(Graph):
"""
raise NotImplementedError()

def to_hierarchical_tree(self):
"""
Generate hierarchical tree based on graph.
"""
from ..hierarchical_tree import HierarchicalTree

tree = HierarchicalTree()
node_input = None
for _, node_name in enumerate(self.nodes_in_topological_order):
node_inst = self.get_node(node_name)
node_output = self._shape_dict.get(node_name)
if node_inst.in_degree == 0:
# If in-degree equals to zero, then it's a input node.
continue

# If the node is on the top, then fetch its input
# from input table.
if not node_input:
node_input = self._input_shape.get(node_name)

if not node_input:
error = ValueError(f"This model is not supported now. "
f"Cannot find {node_name}'s input shape.")
log.error(str(error))
log.exception(error)
raise error

tree.insert(node_inst, node_name, node_input, node_output)
node_input = node_output
return tree

def build_connection(self, src, tgt) -> NoReturn:
"""
Build connection between source node and target node.


Loading…
Cancel
Save