Browse Source

Fix global context reference problem.

tags/v1.2.0-rc1
liuchongming 4 years ago
parent
commit
e089c974a3
8 changed files with 95 additions and 54 deletions
  1. +12
    -2
      mindinsight/mindconverter/graph_based_converter/common/global_context.py
  2. +37
    -0
      mindinsight/mindconverter/graph_based_converter/common/yapf_config.py
  3. +0
    -5
      mindinsight/mindconverter/graph_based_converter/constant.py
  4. +5
    -1
      mindinsight/mindconverter/graph_based_converter/framework.py
  5. +17
    -16
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  6. +7
    -13
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  7. +16
    -16
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  8. +1
    -1
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py

+ 12
- 2
mindinsight/mindconverter/graph_based_converter/common/global_context.py View File

@@ -26,6 +26,11 @@ class Singleton(type):
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]

@classmethod
def release(mcs):
"""Clear singleton object."""
mcs._instances.clear()


class GlobalContext(metaclass=Singleton):
"""
@@ -110,7 +115,7 @@ class GlobalContext(metaclass=Singleton):
if isinstance(arg, OrderedDict):
self._onnx_nodes_collection = arg # arg must be nodes_dict in OnnxDataLoader
else:
raise TypeError("GlobalContext received an unsupport variable to assign to onnx_nodes_collection.")
raise TypeError("GlobalContext received an unsupported variable to assign to onnx_nodes_collection.")

@property
def onnx_nodes_topo_index(self) -> dict:
@@ -149,7 +154,7 @@ class GlobalContext(metaclass=Singleton):
if isinstance(arg, dict):
self._onnx_tensors_collection = arg # arg must be tensors_dict in OnnxDataLoader
else:
raise TypeError("GlobalContext received an unsupport variable to assign to onnx_tensors_collection.")
raise TypeError("GlobalContext received an unsupported variable to assign to onnx_tensors_collection.")

@property
def latest_node_struct_count(self):
@@ -237,3 +242,8 @@ class GlobalContext(metaclass=Singleton):
self.module_structs[pattern_id] = [module_struct]
else:
self.module_structs[pattern_id].append(module_struct)

@classmethod
def release(cls):
"""Clear singleton object."""
Singleton.release()

+ 37
- 0
mindinsight/mindconverter/graph_based_converter/common/yapf_config.py View File

@@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define code format configuration."""
from yapf.yapflib.style import CreatePEP8Style


def mindspore_yapf_config():
"""Create the PEP8 formatting style."""
style = CreatePEP8Style()
style['ALLOW_SPLIT_BEFORE_DEFAULT_OR_NAMED_ASSIGNS'] = False
style['ALLOW_MULTILINE_LAMBDAS'] = True
style['ALLOW_SPLIT_BEFORE_DICT_VALUE'] = False
style['COLUMN_LIMIT'] = 120
style['COALESCE_BRACKETS'] = True
style['FORCE_MULTILINE_DICT'] = True
style['DISABLE_ENDING_COMMA_HEURISTIC'] = True
style['INDENT_DICTIONARY_VALUE'] = True
style['JOIN_MULTIPLE_LINES'] = False
style['SPACES_BEFORE_COMMENT'] = 2
style['SPLIT_PENALTY_AFTER_OPENING_BRACKET'] = 30
style['SPLIT_PENALTY_BEFORE_IF_EXPR'] = 30
style['SPLIT_PENALTY_FOR_ADDED_LINE_SPLIT'] = 30
style['SPLIT_BEFORE_LOGICAL_OPERATOR'] = False
style['SPLIT_BEFORE_BITWISE_OPERATOR'] = False
return style

+ 0
- 5
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -104,11 +104,6 @@ NO_CONVERTED_OPERATORS = [
]


@unique
class CodeFormatConfig(Enum):
PEP8 = "pep8"


@unique
class NodeType(Enum):
MODULE = "module"


+ 5
- 1
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -20,6 +20,7 @@ from importlib import import_module
from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
save_code_file_and_report, get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
@@ -199,13 +200,14 @@ def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
output_folder (str): Output folder.
report_folder (str): Report output folder path.
"""

graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
input_nodes=input_nodes, output_nodes=output_nodes)
generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
model_name = _extract_model_name(graph_path)
code_fragments = generator_inst.generate()
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
# Release global context.
GlobalContext.release()


@tf_installation_validation
@@ -238,6 +240,8 @@ def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
model_name = _extract_model_name(graph_path)
code_fragments = generator_inst.generate()
save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
# Release global context.
GlobalContext.release()


@BaseConverterError.uniform_catcher()


+ 17
- 16
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -18,16 +18,18 @@ from collections import OrderedDict

from yapf.yapflib.yapf_api import FormatCode

from .scope_utils import Scope
from .node_struct import NodeStruct
from .module_struct import ModuleStruct
from .args_translator import ArgsTranslationHelper
from ..common.global_context import GlobalContext
from ..common.outputs import BaseOutput, ModuleOutputManager
from ...common.exceptions import GeneratorError
from ..common.name_mgr import GlobalVarNameMgr
from ..constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, CodeFormatConfig, get_imported_module
from ..report_generator import ReportGenerator
from mindinsight.mindconverter.common.exceptions import GeneratorError
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct
from mindinsight.mindconverter.graph_based_converter.generator.module_struct import ModuleStruct
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslationHelper
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseOutput, ModuleOutputManager
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, FIRST_LEVEL_INDENT, \
get_imported_module
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator


class CodeStruct:
@@ -35,7 +37,6 @@ class CodeStruct:
Define the Code template for each module generated in the final output.
Each module has only one CodeStruct to its pattern.
"""
GLOBAL_CONTEXT = GlobalContext()
NOT_IN_SCOPE_OPT = dict()

def __init__(self, struct, repeated_submodules=None):
@@ -104,7 +105,7 @@ class CodeStruct:

elif isinstance(struct, ModuleStruct):
# check if this instance generated CodeStruct
if self.GLOBAL_CONTEXT.code_structs.get(struct.pattern_id) is None:
if GlobalContext().code_structs.get(struct.pattern_id) is None:
CodeStruct(struct, repeated_submodules)

code_line_init = struct.code_line_in_init()
@@ -138,10 +139,10 @@ class CodeStruct:
returns = list(set(returns))
else:
returns = [code_line_construct[0]] if isinstance(code_line_construct, tuple) \
else [code_line_construct[-1].replace(' ', '').split('=')[0]]
else [code_line_construct[-1].replace(' ', '').split('=')[0]]
self.new_line = f"{SECOND_LEVEL_INDENT}return {', '.join(returns)}"
self.new_line = f"{NEW_LINE * 2}"
self.GLOBAL_CONTEXT.code_structs[md_struct.pattern_id] = self
GlobalContext().code_structs[md_struct.pattern_id] = self


class Generator:
@@ -482,7 +483,7 @@ class Generator:
outputs.append(line)

formatted_code, _ = FormatCode("\n".join(outputs),
style_config=CodeFormatConfig.PEP8.value)
style_config=mindspore_yapf_config())

report_generator = ReportGenerator()
report = report_generator.gen_report(formatted_code)
@@ -589,7 +590,7 @@ class Generator:
output_obj.idx_in_ms_user[nd_struct.identifier] = idx

# set this output to be returned to external
output_obj.to_external = not(nd_struct.check_target_node_internal(
output_obj.to_external = not (nd_struct.check_target_node_internal(
self._global_context.outputs_storage.onnx_name(inp)
))



+ 7
- 13
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -31,11 +31,10 @@ class ModuleStruct:
Define a module struct which stores all info. to generate statement.

Args:
args (list): A list of node structs.
nd_struct_list (list): A list of node structs.
init_as_parent (bool): Control init method if the ModuleStruct be init as a parent module struct.
parent_base (ModuleStruct): The base ModuleStruct the current ModuleStruct to be init as.
"""
GLOBAL_CONTEXT_MGR = GlobalContext()

def __init__(self, nd_struct_list, init_as_parent=False, parent_base=None):
"""Init. a module by NodeStructs."""
@@ -247,7 +246,7 @@ class ModuleStruct:
self._module_structs += md_structs
tail_md = md_structs[-1]
else:
raise TypeError("ModuleStruct cannot add an unsupport Type {} to module_structs list.".format(
raise TypeError("ModuleStruct cannot add an unsupported Type {} to module_structs list.".format(
type(md_structs)))
# update tail node and index
if self.tail_nd_struct_index < tail_md.tail_nd_struct_index:
@@ -318,12 +317,7 @@ class ModuleStruct:
return ret

def code_line_in_init(self):
"""
Initialization line of code in module init block.

Args:
override_formal_val (dict): Indicate which args should be renamed for passing value from upper level.
"""
"""Initialization line of code in module init block."""
left = "self.{}".format(self.ms_var_name)
args_list = list()
# Load args in init statement.
@@ -338,7 +332,7 @@ class ModuleStruct:
else:
args_list += self._fragment.actual_args
right = f"{self.class_name}({', '.join(args_list)})"
return (left, right)
return left, right

def code_line_in_construct(self, inputs=None):
"""Construct line of code in module construct block."""
@@ -356,7 +350,7 @@ class ModuleStruct:
if isinstance(inputs, str):
inputs = [inputs]
right = f"self.{self.ms_var_name}({', '.join(inputs)})"
return (left, right)
return left, right

@property
def node_structs(self):
@@ -463,8 +457,8 @@ class ModuleStruct:
"""Return the class name for generating code of this module."""
if self.pattern_id == -1:
return "Model"
if self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = self.GLOBAL_CONTEXT_MGR.known_module_name.get("Module{}".format(self.pattern_id))
if GlobalContext().known_module_name.get("Module{}".format(self.pattern_id)) is not None:
class_name = GlobalContext().known_module_name.get("Module{}".format(self.pattern_id))
else:
class_name = "Module{}".format(self.pattern_id)
return class_name


+ 16
- 16
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -23,6 +23,7 @@ from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorError


class NodeStruct:
"""
Define a node struct which stores all info. to generate statement.
@@ -34,10 +35,10 @@ class NodeStruct:
You can pass as many args as possible and the Node Struct will update
by arguments order.
"""
GLOBAL_CONTEXT_MGR = GlobalContext()

def __init__(self, args):
# define attributes here
self.global_context_mgr = GlobalContext()
self._identifier = None
self._fragment = None
self._args_translator = None
@@ -74,7 +75,7 @@ class NodeStruct:
"""Get the original topological index in the onnx graph."""
ori_name = self._fragment.metadata.get('source')
self.onnx_name = ori_name
return self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_topo_idx.get(ori_name)
return GlobalContext().onnx_node_name_to_topo_idx.get(ori_name)

def update_var_name(self, idx=None):
"""
@@ -83,6 +84,7 @@ class NodeStruct:
Args:
idx (int): The index of the node in this module.
"""

def _remove_op_header(op_name):
"""Remove op header which indicating their sources of op set."""
op_name = op_name.replace('nn.', '')
@@ -112,7 +114,7 @@ class NodeStruct:
self._fragment = FragmentHandler(frag)

if self.ms_op:
idx = self.GLOBAL_CONTEXT_MGR.latest_node_struct_count
idx = GlobalContext().latest_node_struct_count
self.update_var_name(idx=idx)

def _set_scope_from_identifier(self):
@@ -142,9 +144,7 @@ class NodeStruct:

Args:
arg (Union[PyTorchGraphNode, OnnxGraphNode, dict]): Node related obj.
force_ready (bool): Force this NodeStruct is ready to generate.
"""

if isinstance(arg, OnnxGraphNode):
self._update_from_onnx_gn(arg)
elif isinstance(arg, NewFragment):
@@ -168,7 +168,7 @@ class NodeStruct:
self._identifier = s
self._set_scope_from_identifier()
self.topo_idx = self.ori_topo_idx()
self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map[self.onnx_name] = self
GlobalContext().onnx_node_name_to_node_struct_map[self.onnx_name] = self

@property
def fragment(self):
@@ -181,7 +181,7 @@ class NodeStruct:
Set the Node fragment.

Args:
s (NodeFragment): The node identifier string.
frag (NodeFragment): The node identifier string.
"""
self._fragment = frag

@@ -198,7 +198,7 @@ class NodeStruct:
@property
def onnx_node(self):
"""Return the original onnx node reference."""
return self.GLOBAL_CONTEXT_MGR.onnx_nodes_collection.get(self.onnx_name)
return GlobalContext().onnx_nodes_collection.get(self.onnx_name)

@property
def ms_op(self):
@@ -241,7 +241,7 @@ class NodeStruct:
ret = []
precursor_nodes_names = self.precursor_nodes_names
for pre_node_name in precursor_nodes_names:
nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name)
nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct)
return ret

@@ -255,7 +255,7 @@ class NodeStruct:
"""Return the node struct instances of successor nodes."""
ret = []
for pre_node_name in self.successor_nodes_names:
nd_struct = self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(pre_node_name)
nd_struct = GlobalContext().onnx_node_name_to_node_struct_map.get(pre_node_name)
ret.append(nd_struct)
return ret

@@ -312,11 +312,11 @@ class NodeStruct:
inputs = self.matched_inputs

# Check original onnx node's input to ensure double inputs are not ignored
original_inputs = self.GLOBAL_CONTEXT_MGR.onnx_node_inputs.get(self.onnx_name)
original_inputs = GlobalContext().onnx_node_inputs.get(self.onnx_name)
new_inputs = []
for idx, prec_node in enumerate(self.precursor_nodes_names):
occurence = original_inputs.count(prec_node)
for _ in range(occurence):
occurrence = original_inputs.count(prec_node)
for _ in range(occurrence):
new_inputs.append(inputs[idx])
inputs = new_inputs

@@ -360,12 +360,12 @@ class NodeStruct:
Args:
name (str): Can accept both node identifier or original onnx node name.
"""
target_nd_struct = self.GLOBAL_CONTEXT_MGR.node_struct_collections.get(name) \
or self.GLOBAL_CONTEXT_MGR.onnx_node_name_to_node_struct_map.get(name)
target_nd_struct = GlobalContext().node_struct_collections.get(name) \
or GlobalContext().onnx_node_name_to_node_struct_map.get(name)
if target_nd_struct is None and self.topo_idx == 0: # First node always has external input
return False

if target_nd_struct is None and (name in self.GLOBAL_CONTEXT_MGR.onnx_graph_info.get('graph_inputs')):
if target_nd_struct is None and (name in GlobalContext().onnx_graph_info.get('graph_inputs')):
return False

if target_nd_struct is None:


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/pattern.py View File

@@ -34,7 +34,7 @@ class Pattern:
# If pattern in BUILD_IN_MODULE_NAME or BUILD_IN_PATTERN,
# the pattern will get additional score.
self.additional_score = 0
self.know_module_name = None
self.known_module_name = None

def insert(self, idx, seq_len):
"""


Loading…
Cancel
Save