From be17cc4eaecdc4cd277fefb7e45ccb31bcc49198 Mon Sep 17 00:00:00 2001 From: moran Date: Mon, 14 Sep 2020 20:47:28 +0800 Subject: [PATCH] Add one exception catch & Optimize code --- mindinsight/mindconverter/cli.py | 30 +++++++++++++++++- .../third_party_graph/pytorch_graph.py | 31 ------------------- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index 2529e23d..722ae687 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -109,6 +109,8 @@ class ProjectPathAction(argparse.Action): option_string (str): Optional string for specific argument name. Default: None. """ outfile_dir = FileDirAction.check_path(parser, values, option_string) + if not os.path.exists(outfile_dir): + parser.error(f'{option_string} {outfile_dir} not exists') if not os.path.isdir(outfile_dir): parser.error(f'{option_string} [{outfile_dir}] should be a directory.') @@ -138,6 +140,32 @@ class InFileAction(argparse.Action): setattr(namespace, self.dest, outfile_dir) +class ModelFileAction(argparse.Action): + """Model File action class definition.""" + + def __call__(self, parser, namespace, values, option_string=None): + """ + Inherited __call__ method from argparse.Action. + + Args: + parser (ArgumentParser): Passed-in argument parser. + namespace (Namespace): Namespace object to hold arguments. + values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. + """ + outfile_dir = FileDirAction.check_path(parser, values, option_string) + if not os.path.exists(outfile_dir): + parser.error(f'{option_string} {outfile_dir} not exists') + + if not os.path.isfile(outfile_dir): + parser.error(f'{option_string} {outfile_dir} is not a file') + + if not outfile_dir.endswith('.pth'): + parser.error(f"{option_string} {outfile_dir} should be a Pytorch model, ending with '.pth'.") + + setattr(namespace, self.dest, outfile_dir) + + class LogFileAction(argparse.Action): """Log file action class definition.""" @@ -208,7 +236,7 @@ def cli_entry(): parser.add_argument( '--model_file', type=str, - action=InFileAction, + action=ModelFileAction, required=False, help=""" Pytorch .pth model file path ot use graph diff --git a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py index a7a2e70a..12a5bd72 100644 --- a/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py +++ b/mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph.py @@ -204,37 +204,6 @@ class PyTorchGraph(Graph): """ raise NotImplementedError() - def to_hierarchical_tree(self): - """ - Generate hierarchical tree based on graph. - """ - from ..hierarchical_tree import HierarchicalTree - - tree = HierarchicalTree() - node_input = None - for _, node_name in enumerate(self.nodes_in_topological_order): - node_inst = self.get_node(node_name) - node_output = self._shape_dict.get(node_name) - if node_inst.in_degree == 0: - # If in-degree equals to zero, then it's a input node. - continue - - # If the node is on the top, then fetch its input - # from input table. - if not node_input: - node_input = self._input_shape.get(node_name) - - if not node_input: - error = ValueError(f"This model is not supported now. " - f"Cannot find {node_name}'s input shape.") - log.error(str(error)) - log.exception(error) - raise error - - tree.insert(node_inst, node_name, node_input, node_output) - node_input = node_output - return tree - def build_connection(self, src, tgt) -> NoReturn: """ Build connection between source node and target node.