Browse Source

Fix global context reference problem.

tags/v1.2.0-rc1
liuchongming 4 years ago
parent
commit
e089c974a3
8 changed files with 95 additions and 54 deletions
  1. +12
    -2
      mindinsight/mindconverter/graph_based_converter/common/global_context.py
  2. +37
    -0
      mindinsight/mindconverter/graph_based_converter/common/yapf_config.py
  3. +0
    -5
      mindinsight/mindconverter/graph_based_converter/constant.py
  4. +5
    -1
      mindinsight/mindconverter/graph_based_converter/framework.py
  5. +17
    -16
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  6. +7
    -13
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  7. +16
    -16
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  8. +1
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py

+ 12
- 2
mindinsight/mindconverter/graph_based_converter/common/global_context.py View File

@@ -26,6 +26,11 @@ class Singleton(type):
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls] return cls._instances[cls]


@classmethod
def release(mcs):
"""Clear singleton object."""
mcs._instances.clear()



class GlobalContext(metaclass=Singleton): class GlobalContext(metaclass=Singleton):
""" """
@@ -110,7 +115,7 @@ class GlobalContext(metaclass=Singleton):
if isinstance(arg, OrderedDict): if isinstance(arg, OrderedDict):
self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader
else: else:
raise TypeError("GlobalContext received an unsupport variable to assign to onnx_nodes_collection.")
raise TypeError("GlobalContext received an unsupported variable to assign to onnx_nodes_collection.")


@property @property
def onnx_nodes_topo_index(self) -> dict: def onnx_nodes_topo_index(self) -> dict:
@@ -149,7 +154,7 @@ class GlobalContext(metaclass=Singleton):
if isinstance(arg, dict): if isinstance(arg, dict):
self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader
else: else:
raise TypeError("GlobalContext received an unsupport variable to assign to onnx_tensors_collection.")
raise TypeError("GlobalContext received an unsupported variable to assign to onnx_tensors_collection.")


@property @property
def latest_node_struct_count(self): def latest_node_struct_count(self):
@@ -237,3 +242,8 @@ class GlobalContext(metaclass=Singleton):
self.module_structs[pattern_id] = [module_struct] self.module_structs[pattern_id] = [module_struct]
else: else:
self.module_structs[pattern_id].append(module_struct) self.module_structs[pattern_id].append(module_struct)

@classmethod
def release(cls):
"""Clear singleton object."""
Singleton.release()

+ 37
- 0
mindinsight/mindconverter/graph_based_converter/common/yapf_config.py View File

@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define code format configuration."""
from yapf.yapflib.style import CreatePEP8Style


def mindspore_yapf_config():
"""Create the PEP8 formatting style."""
style = CreatePEP8Style()
style['ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS'] = False
style['ALLOW_MULTILINE_LAMBDAS'] = True
style['ALLOW_SPLIT_BEFORE_DICT_VALUE'] = False
style['COLUMN_LIMIT'] = 120
style['COALESCE_BRACKETS'] = True
style['FORCE_MULTILINE_DICT'] = True
style['DISABLE_ENDING_COMMA_HEURISTIC'] = True
style['INDENT_DICTIONARY_VALUE'] = True
style['JOIN_MULTIPLE_LINES'] = False
style['SPACES_BEFORE_COMMENT'] = 2
style['SPLIT_PENALTY_AFTER_OPENING_BRACKET'] = 30
style['SPLIT_PENALTY_BEFORE_IF_EXPR'] = 30
style['SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT'] = 30
style['SPLIT_BEFORE_LOGICAL_OPERATOR'] = False
style['SPLIT_BEFORE_BITWISE_OPERATOR'] = False
return style

+ 0
- 5
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -104,11 +104,6 @@ NO_CONVERTED_OPERATORS = [
] ]




@unique
class CodeFormatConfig(Enum):
PEP8 = "pep8"


@unique @unique
class NodeType(Enum): class NodeType(Enum):
MODULE = "module" MODULE = "module"


+ 5
- 1
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -20,6 +20,7 @@ from importlib import import_module
from importlib.util import find_spec from importlib.util import find_spec


import mindinsight import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \ from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
save_code_file_and_report, get_framework_type save_code_file_and_report, get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
@@ -199,13 +200,14 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder (str): Output folder. output_folder (str): Output folder.
report_folder (str): Report output folder path. report_folder (str): Report output folder path.
""" """

graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape, graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes) input_nodes=input_nodes, output_nodes=output_nodes)
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper) generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path) model_name = _extract_model_name(graph_path)
code_fragments = generator_inst.generate() code_fragments = generator_inst.generate()
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
# Release global context.
GlobalContext.release()




@tf_installation_validation @tf_installation_validation
@@ -238,6 +240,8 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
model_name = _extract_model_name(graph_path) model_name = _extract_model_name(graph_path)
code_fragments = generator_inst.generate() code_fragments = generator_inst.generate()
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder) save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
# Release global context.
GlobalContext.release()




@BaseConverterError.uniform_catcher() @BaseConverterError.uniform_catcher()


+ 17
- 16
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -18,16 +18,18 @@ from collections import OrderedDict


from yapf.yapflib.yapf_api import FormatCode from yapf.yapflib.yapf_api import FormatCode


from .scope_utils import Scope
from .node_struct import NodeStruct
from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext
from ..common.outputs import BaseOutput, ModuleOutputManager
from ...common.exceptions import GeneratorError
from ..common.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
from ..report_generator import ReportGenerator
from mindinsight.mindconverter.common.exceptions import GeneratorError
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct
from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslationHelper
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseOutput, ModuleOutputManager
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, \
get_imported_module
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator




class CodeStruct: class CodeStruct:
@@ -35,7 +37,6 @@ class CodeStruct:
Define the Code template for each module generated in the final output. Define the Code template for each module generated in the final output.
Each module has only one CodeStruct to its pattern. Each module has only one CodeStruct to its pattern.
""" """
GLOBAL_CONTEXT = GlobalContext()
NOT_IN_SCOPE_OPT = dict() NOT_IN_SCOPE_OPT = dict()


def __init__(self, struct, repeated_submodules=None): def __init__(self, struct, repeated_submodules=None):
@@ -104,7 +105,7 @@ class CodeStruct:


elif isinstance(struct, ModuleStruct): elif isinstance(struct, ModuleStruct):
# check if this instance generated CodeStruct # check if this instance generated CodeStruct
if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None:
if GlobalContext().code_structs.get(struct.pattern_id) is None:
CodeStruct(struct, repeated_submodules) CodeStruct(struct, repeated_submodules)


code_line_init = struct.code_line_in_init() code_line_init = struct.code_line_in_init()
@@ -138,10 +139,10 @@ class CodeStruct:
returns = list(set(returns)) returns = list(set(returns))
else: else:
returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \ returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \
else [code_line_construct[-1].replace(' ', '').split('=')[0]]
else [code_line_construct[-1].replace(' ', '').split('=')[0]]
self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}" self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}"
self.new_line = f"{NEW_LINE * 2}" self.new_line = f"{NEW_LINE * 2}"
self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self
GlobalContext().code_structs[md_struct.pattern_id] = self




class Generator: class Generator:
@@ -482,7 +483,7 @@ class Generator:
outputs.append(line) outputs.append(line)


formatted_code, _ = FormatCode("\n".join(outputs), formatted_code, _ = FormatCode("\n".join(outputs),
style_config=CodeFormatConfig.PEP8.value)
style_config=mindspore_yapf_config())


report_generator = ReportGenerator() report_generator = ReportGenerator()
report = report_generator.gen_report(formatted_code) report = report_generator.gen_report(formatted_code)
@@ -589,7 +590,7 @@ class Generator:
output_obj.idx_in_ms_user[nd_struct.identifier] = idx output_obj.idx_in_ms_user[nd_struct.identifier] = idx


# set this output to be returned to external # set this output to be returned to external
output_obj.to_external = not(nd_struct.check_target_node_internal(
output_obj.to_external = not (nd_struct.check_target_node_internal(
self._global_context.outputs_storage.onnx_name(inp) self._global_context.outputs_storage.onnx_name(inp)
)) ))




+ 7
- 13
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -31,11 +31,10 @@ class ModuleStruct:
Define a module struct which stores all info. to generate statement. Define a module struct which stores all info. to generate statement.


Args: Args:
args (list): A list of node structs.
nd_struct_list (list): A list of node structs.
init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct. init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct.
parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as. parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as.
""" """
GLOBAL_CONTEXT_MGR = GlobalContext()


def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None): def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None):
"""Init. a module by NodeStructs.""" """Init. a module by NodeStructs."""
@@ -247,7 +246,7 @@ class ModuleStruct:
self._module_structs += md_structs self._module_structs += md_structs
tail_md = md_structs[-1] tail_md = md_structs[-1]
else: else:
raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format(
raise TypeError("ModuleStruct cannot add an unsupported Type {} to module_structs list.".format(
type(md_structs))) type(md_structs)))
# update tail node and index # update tail node and index
if self.tail_nd_struct_index < tail_md.tail_nd_struct_index: if self.tail_nd_struct_index < tail_md.tail_nd_struct_index:
@@ -318,12 +317,7 @@ class ModuleStruct:
return ret return ret


def code_line_in_init(self): def code_line_in_init(self):
"""
Initialization line of code in module init block.

Args:
override_formal_val (dict): Indicate which args should be renamed for passing value from upper level.
"""
"""Initialization line of code in module init block."""
left = "self.{}".format(self.ms_var_name) left = "self.{}".format(self.ms_var_name)
args_list = list() args_list = list()
# Load args in init statement. # Load args in init statement.
@@ -338,7 +332,7 @@ class ModuleStruct:
else: else:
args_list += self._fragment.actual_args args_list += self._fragment.actual_args
right = f"{self.class_name}({', '.join(args_list)})" right = f"{self.class_name}({', '.join(args_list)})"
return (left, right)
return left, right


def code_line_in_construct(self, inputs=None): def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block.""" """Construct line of code in module construct block."""
@@ -356,7 +350,7 @@ class ModuleStruct:
if isinstance(inputs, str): if isinstance(inputs, str):
inputs = [inputs] inputs = [inputs]
right = f"self.{self.ms_var_name}({', '.join(inputs)})" right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return (left, right)
return left, right


@property @property
def node_structs(self): def node_structs(self):
@@ -463,8 +457,8 @@ class ModuleStruct:
"""Return the class name for generating code of this module.""" """Return the class name for generating code of this module."""
if self.pattern_id == -1: if self.pattern_id == -1:
return "Model" return "Model"
if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id))
if GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = GlobalContext().known_module_name.get("Module{}".format(self.pattern_id))
else: else:
class_name = "Module{}".format(self.pattern_id) class_name = "Module{}".format(self.pattern_id)
return class_name return class_name


+ 16
- 16
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -23,6 +23,7 @@ from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorError from ...common.exceptions import GeneratorError



class NodeStruct: class NodeStruct:
""" """
Define a node struct which stores all info. to generate statement. Define a node struct which stores all info. to generate statement.
@@ -34,10 +35,10 @@ class NodeStruct:
You can pass as many args as possible and the Node Struct will update You can pass as many args as possible and the Node Struct will update
by arguments order. by arguments order.
""" """
GLOBAL_CONTEXT_MGR = GlobalContext()


def __init__(self, args): def __init__(self, args):
# define attributes here # define attributes here
self.global_context_mgr = GlobalContext()
self._identifier = None self._identifier = None
self._fragment = None self._fragment = None
self._args_translator = None self._args_translator = None
@@ -74,7 +75,7 @@ class NodeStruct:
"""Get the original topological index in the onnx graph.""" """Get the original topological index in the onnx graph."""
ori_name = self._fragment.metadata.get('source') ori_name = self._fragment.metadata.get('source')
self.onnx_name = ori_name self.onnx_name = ori_name
return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name)
return GlobalContext().onnx_node_name_to_topo_idx.get(ori_name)


def update_var_name(self, idx=None): def update_var_name(self, idx=None):
""" """
@@ -83,6 +84,7 @@ class NodeStruct:
Args: Args:
idx (int): The index of the node in this module. idx (int): The index of the node in this module.
""" """

def _remove_op_header(op_name): def _remove_op_header(op_name):
"""Remove op header which indicating their sources of op set.""" """Remove op header which indicating their sources of op set."""
op_name = op_name.replace('nn.', '') op_name = op_name.replace('nn.', '')
@@ -112,7 +114,7 @@ class NodeStruct:
self._fragment = FragmentHandler(frag) self._fragment = FragmentHandler(frag)


if self.ms_op: if self.ms_op:
idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count
idx = GlobalContext().latest_node_struct_count
self.update_var_name(idx=idx) self.update_var_name(idx=idx)


def _set_scope_from_identifier(self): def _set_scope_from_identifier(self):
@@ -142,9 +144,7 @@ class NodeStruct:


Args: Args:
arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj. arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj.
force_ready (bool): Force this NodeStruct is ready to generate.
""" """

if isinstance(arg, OnnxGraphNode): if isinstance(arg, OnnxGraphNode):
self._update_from_onnx_gn(arg) self._update_from_onnx_gn(arg)
elif isinstance(arg, NewFragment): elif isinstance(arg, NewFragment):
@@ -168,7 +168,7 @@ class NodeStruct:
self._identifier = s self._identifier = s
self._set_scope_from_identifier() self._set_scope_from_identifier()
self.topo_idx = self.ori_topo_idx() self.topo_idx = self.ori_topo_idx()
self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self
GlobalContext().onnx_node_name_to_node_struct_map[self.onnx_name] = self


@property @property
def fragment(self): def fragment(self):
@@ -181,7 +181,7 @@ class NodeStruct:
Set the Node fragment. Set the Node fragment.


Args: Args:
s (NodeFragment): The node identifier string.
frag (NodeFragment): The node identifier string.
""" """
self._fragment = frag self._fragment = frag


@@ -198,7 +198,7 @@ class NodeStruct:
@property @property
def onnx_node(self): def onnx_node(self):
"""Return the original onnx node reference.""" """Return the original onnx node reference."""
return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name)
return GlobalContext().onnx_nodes_collection.get(self.onnx_name)


@property @property
def ms_op(self): def ms_op(self):
@@ -241,7 +241,7 @@ class NodeStruct:
ret = [] ret = []
precursor_nodes_names = self.precursor_nodes_names precursor_nodes_names = self.precursor_nodes_names
for pre_node_name in precursor_nodes_names: for pre_node_name in precursor_nodes_names:
nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name)
nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct) ret.append(nd_struct)
return ret return ret


@@ -255,7 +255,7 @@ class NodeStruct:
"""Return the node struct instances of successor nodes.""" """Return the node struct instances of successor nodes."""
ret = [] ret = []
for pre_node_name in self.successor_nodes_names: for pre_node_name in self.successor_nodes_names:
nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name)
nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct) ret.append(nd_struct)
return ret return ret


@@ -312,11 +312,11 @@ class NodeStruct:
inputs = self.matched_inputs inputs = self.matched_inputs


# Check original onnx node's input to ensure double inputs are not ignored # Check original onnx node's input to ensure double inputs are not ignored
original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name)
original_inputs = GlobalContext().onnx_node_inputs.get(self.onnx_name)
new_inputs = [] new_inputs = []
for idx, prec_node in enumerate(self.precursor_nodes_names): for idx, prec_node in enumerate(self.precursor_nodes_names):
occurence = original_inputs.count(prec_node)
for _ in range(occurence):
occurrence = original_inputs.count(prec_node)
for _ in range(occurrence):
new_inputs.append(inputs[idx]) new_inputs.append(inputs[idx])
inputs = new_inputs inputs = new_inputs


@@ -360,12 +360,12 @@ class NodeStruct:
Args: Args:
name (str): Can accept both node identifier or original onnx node name. name (str): Can accept both node identifier or original onnx node name.
""" """
target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \
or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name)
target_nd_struct = GlobalContext().node_struct_collections.get(name) \
or GlobalContext().onnx_node_name_to_node_struct_map.get(name)
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False return False


if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')):
if target_nd_struct is None and (name in GlobalContext().onnx_graph_info.get('graph_inputs')):
return False return False


if target_nd_struct is None: if target_nd_struct is None:


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -34,7 +34,7 @@ class Pattern:
# If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN, # If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
# the pattern will get additional score. # the pattern will get additional score.
self.additional_score = 0 self.additional_score = 0
self.know_module_name = None
self.known_module_name = None


def insert(self, idx, seq_len): def insert(self, idx, seq_len):
""" """


Loading…
Cancel
Save