Browse Source

!1132 Add trainable weights transformer module in MindConverter

From: @moran3
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
22a1b9e652
37 changed files with 478 additions and 175 deletions
  1. +31
    -1
      mindinsight/mindconverter/common/exceptions.py
  2. +3
    -3
      mindinsight/mindconverter/graph_based_converter/__init__.py
  3. +6
    -1
      mindinsight/mindconverter/graph_based_converter/common/code_fragment.py
  4. +1
    -1
      mindinsight/mindconverter/graph_based_converter/common/global_context.py
  5. +33
    -3
      mindinsight/mindconverter/graph_based_converter/common/utils.py
  6. +1
    -0
      mindinsight/mindconverter/graph_based_converter/constant.py
  7. +48
    -30
      mindinsight/mindconverter/graph_based_converter/framework.py
  8. +4
    -4
      mindinsight/mindconverter/graph_based_converter/generator/__init__.py
  9. +74
    -2
      mindinsight/mindconverter/graph_based_converter/generator/generator.py
  10. +7
    -7
      mindinsight/mindconverter/graph_based_converter/generator/module_struct.py
  11. +5
    -5
      mindinsight/mindconverter/graph_based_converter/generator/node_struct.py
  12. +2
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/__init__.py
  13. +14
    -7
      mindinsight/mindconverter/graph_based_converter/mapper/base.py
  14. +11
    -6
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py
  15. +19
    -20
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py
  16. +7
    -6
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py
  17. +4
    -3
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py
  18. +1
    -6
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py
  19. +103
    -10
      mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py
  20. +1
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py
  21. +1
    -2
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py
  22. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py
  23. +1
    -1
      mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py
  24. +2
    -2
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py
  25. +3
    -3
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py
  26. +10
    -7
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py
  27. +9
    -6
      mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py
  28. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py
  29. +3
    -4
      mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py
  30. +9
    -9
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py
  31. +6
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py
  32. +20
    -7
      mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py
  33. +1
    -1
      mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py
  34. +1
    -0
      mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py
  35. +3
    -3
      mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py
  36. +6
    -2
      tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py
  37. +26
    -0
      tests/utils/mindspore/train/serialization.py

+ 31
- 1
mindinsight/mindconverter/common/exceptions.py View File

@@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -301,6 +301,8 @@ class SourceFilesSaveError(MindConverterException):
NODE_INPUT_TYPE_NOT_SUPPORT = 1 NODE_INPUT_TYPE_NOT_SUPPORT = 1
SCRIPT_GENERATE_FAIL = 2 SCRIPT_GENERATE_FAIL = 2
REPORT_GENERATE_FAIL = 3 REPORT_GENERATE_FAIL = 3
CKPT_GENERATE_FAIL = 4
MAP_GENERATE_FAIL = 5


BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value BASE_ERROR_CODE = ConverterErrors.SOURCE_FILES_SAVE_FAIL.value
ERROR_CODE = ErrCode.UNKNOWN_ERROR.value ERROR_CODE = ErrCode.UNKNOWN_ERROR.value
@@ -315,6 +317,8 @@ class SourceFilesSaveError(MindConverterException):
except_source = (NodeInputTypeNotSupportError, except_source = (NodeInputTypeNotSupportError,
ScriptGenerationError, ScriptGenerationError,
ReportGenerationError, ReportGenerationError,
CheckPointGenerationError,
WeightMapGenerationError,
IOError, cls) IOError, cls)
return except_source return except_source


@@ -437,6 +441,32 @@ class ReportGenerationError(SourceFilesSaveError):
return ZeroDivisionError, cls return ZeroDivisionError, cls




class CheckPointGenerationError(SourceFilesSaveError):
"""The checkpoint generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.CKPT_GENERATE_FAIL.value

def __init__(self, msg):
super(CheckPointGenerationError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
"""Raise from exceptions below."""
return cls


class WeightMapGenerationError(SourceFilesSaveError):
"""The weight names map generate fail error."""
ERROR_CODE = SourceFilesSaveError.ErrCode.MAP_GENERATE_FAIL.value

def __init__(self, msg):
super(WeightMapGenerationError, self).__init__(msg=msg)

@classmethod
def raise_from(cls):
"""Raise from exception below."""
return cls


class SubGraphSearchingError(MindConverterException): class SubGraphSearchingError(MindConverterException):
"""Sub-graph searching exception.""" """Sub-graph searching exception."""




+ 3
- 3
mindinsight/mindconverter/graph_based_converter/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -15,5 +15,5 @@
"""Graph based scripts converter definition.""" """Graph based scripts converter definition."""
__all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"] __all__ = ["graph_based_converter_pytorch_to_ms", "graph_based_converter_tf_to_ms"]


from .framework import graph_based_converter_pytorch_to_ms
from .framework import graph_based_converter_tf_to_ms
from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_pytorch_to_ms
from mindinsight.mindconverter.graph_based_converter.framework import graph_based_converter_tf_to_ms

+ 6
- 1
mindinsight/mindconverter/graph_based_converter/common/code_fragment.py View File

@@ -191,18 +191,23 @@ class CodeFragment(Fragment):
""" """


def __init__(self, operation, actual_args, settings, input_shape, output_shape, def __init__(self, operation, actual_args, settings, input_shape, output_shape,
trainable_params=None):
trainable_params=None, trainable_weights=None):
super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args, super(CodeFragment, self).__init__(operation=operation, actual_args=actual_args,
input_shape=input_shape, output_shape=output_shape, input_shape=input_shape, output_shape=output_shape,
settings=settings) settings=settings)
self._trainable_params = dict() # External weights, like Matmul. self._trainable_params = dict() # External weights, like Matmul.
self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d. self._init_trainable_params = trainable_params # Can put into operation init method, like Conv2d.
self._trainable_weights = trainable_weights


@property @property
def trainable_params(self): def trainable_params(self):
"""Return the trainable parameters.""" """Return the trainable parameters."""
return self._trainable_params return self._trainable_params


@property
def trainable_weights(self):
return self._trainable_weights



class ModuleFragment(Fragment): class ModuleFragment(Fragment):
"""Manage module type code variables.""" """Manage module type code variables."""


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/common/global_context.py View File

@@ -14,7 +14,7 @@
# ============================================================================== # ==============================================================================
"""Define GlobalContext class to save required resources during whole conversion procedure.""" """Define GlobalContext class to save required resources during whole conversion procedure."""
from collections import OrderedDict from collections import OrderedDict
from .outputs import OutputStorage
from mindinsight.mindconverter.graph_based_converter.common.outputs import OutputStorage




class Singleton(type): class Singleton(type):


+ 33
- 3
mindinsight/mindconverter/graph_based_converter/common/utils.py View File

@@ -13,16 +13,21 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Define common utils.""" """Define common utils."""
import json
import os import os
import stat import stat
from importlib import import_module from importlib import import_module
from importlib.util import find_spec
from typing import List, Tuple, Mapping from typing import List, Tuple, Mapping


from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, UnknownModelError
from mindinsight.mindconverter.common.exceptions import ScriptGenerationError, ReportGenerationError, \
UnknownModelError, CheckPointGenerationError, WeightMapGenerationError
from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \ from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_ONNX_OP, BINARY_HEADER_PYTORCH_BITS, \
FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX FrameworkType, BINARY_HEADER_PYTORCH_FILE, TENSORFLOW_MODEL_SUFFIX


from mindspore.train.serialization import save_checkpoint



def is_converted(operation: str): def is_converted(operation: str):
""" """
@@ -96,7 +101,6 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
code_lines (dict): Code lines. code_lines (dict): Code lines.
out_folder (str): Output folder. out_folder (str): Output folder.
report_folder (str): Report output folder. report_folder (str): Report output folder.

""" """
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IRUSR | stat.S_IWUSR modes = stat.S_IRUSR | stat.S_IWUSR
@@ -114,7 +118,7 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
os.makedirs(report_folder, modes_usr) os.makedirs(report_folder, modes_usr)


for file_name in code_lines: for file_name in code_lines:
code, report = code_lines[file_name]
code, report, trainable_weights, weight_map = code_lines[file_name]
code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py")) code_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.py"))
report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt")) report_file_path = os.path.realpath(os.path.join(report_folder, f"report_of_{model_name}.txt"))
try: try:
@@ -133,6 +137,31 @@ def save_code_file_and_report(model_name: str, code_lines: Mapping[str, Tuple],
except (IOError, FileExistsError) as error: except (IOError, FileExistsError) as error:
raise ReportGenerationError(str(error)) raise ReportGenerationError(str(error))


ckpt_file_path = os.path.realpath(os.path.join(out_folder, f"{model_name}.ckpt"))
try:
if os.path.exists(ckpt_file_path):
raise CheckPointGenerationError("Checkpoint file with the same name already exists.")
save_checkpoint(trainable_weights, ckpt_file_path)
except TypeError as error:
raise CheckPointGenerationError(str(error))

weight_map_path = os.path.realpath(os.path.join(out_folder, f"weight_map_of_{model_name}.json"))
try:
if os.path.exists(weight_map_path):
raise WeightMapGenerationError("Weight map file with the same name already exists.")
with os.fdopen(os.open(weight_map_path, flags, stat.S_IRUSR), 'w') as map_f:
weight_map_json = {f"{model_name}": weight_map}
json.dump(weight_map_json, map_f)
except (IOError, FileExistsError) as error:
raise WeightMapGenerationError(str(error))


def onnx_satisfied():
"""Validate ONNX , ONNXRUNTIME, ONNXOPTIMIZER installation."""
if not find_spec("onnx") or not find_spec("onnxruntime") or not find_spec("onnxoptimizer"):
return False
return True



def lib_version_satisfied(current_ver: str, mini_ver_limited: str, def lib_version_satisfied(current_ver: str, mini_ver_limited: str,
newest_ver_limited: str = ""): newest_ver_limited: str = ""):
@@ -220,6 +249,7 @@ def reset_init_or_construct(template, variable_slot, new_data, scope):
template[variable_slot][scope] += new_data template[variable_slot][scope] += new_data
return template return template



def replace_string_in_list(str_list: list, original_str: str, target_str: str): def replace_string_in_list(str_list: list, original_str: str, target_str: str):
""" """
Replace a string in a list by provided string. Replace a string in a list by provided string.


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/constant.py View File

@@ -41,6 +41,7 @@ UNKNOWN_DIM_VAL = "unk__001"
ONNX_MIN_VER = "1.8.0" ONNX_MIN_VER = "1.8.0"
TF2ONNX_MIN_VER = "1.7.1" TF2ONNX_MIN_VER = "1.7.1"
ONNXRUNTIME_MIN_VER = "1.5.2" ONNXRUNTIME_MIN_VER = "1.5.2"
ONNXOPTIMIZER_MIN_VER = "0.1.2"




@unique @unique


+ 48
- 30
mindinsight/mindconverter/graph_based_converter/framework.py View File

@@ -21,10 +21,10 @@ from importlib.util import find_spec


import mindinsight import mindinsight
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, onnx_satisfied, \
save_code_file_and_report, get_framework_type save_code_file_and_report, get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \ from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER, ONNXOPTIMIZER_MIN_VER
from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
@@ -53,6 +53,18 @@ parser.add_argument("--report", type=str, required=False,
help="Generated reports output folder path.") help="Generated reports output folder path.")




def onnx_lib_version_satisfied():
"""Check onnx libs version whether is satisfied."""
onnx = import_module("onnx")
ort = import_module("onnxruntime")
optimizer = import_module("onnxoptimizer.version")
if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
or not lib_version_satisfied(getattr(optimizer, "version"), ONNXOPTIMIZER_MIN_VER):
return False
return True


def torch_installation_validation(func): def torch_installation_validation(func):
""" """
Validate args of func. Validate args of func.
@@ -68,26 +80,33 @@ def torch_installation_validation(func):
input_nodes: str, output_nodes: str, input_nodes: str, output_nodes: str,
output_folder: str, report_folder: str = None): output_folder: str, report_folder: str = None):
# Check whether pytorch is installed. # Check whether pytorch is installed.
if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) "
f"are required when using graph based "
f"scripts converter, and PyTorch version must "
f"be consisted with model generation runtime.")
error_info = None
if graph_path.endswith('.onnx'):
if not onnx_satisfied():
error_info = f"onnx(>={ONNX_MIN_VER}, onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \
f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \
f"are required when using graph based scripts converter."
else:
if not find_spec("torch") or not onnx_satisfied():
error_info = f"PyTorch, onnx(>={ONNX_MIN_VER}), " \
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and " \
f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) " \
f"are required when using graph based " \
f"scripts converter, and PyTorch version must " \
f"be consisted with model generation runtime."
if error_info:
error = RuntimeIntegrityError(error_info)
log.error(error) log.error(error)
log_console.error("\n") log_console.error("\n")
log_console.error(str(error)) log_console.error(str(error))
log_console.error("\n") log_console.error("\n")
sys.exit(0) sys.exit(0)


onnx = import_module("onnx")
ort = import_module("onnxruntime")

if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER):
if not onnx_lib_version_satisfied():
error = RuntimeIntegrityError( error = RuntimeIntegrityError(
f"onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"onnx(>={ONNX_MIN_VER}), "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and "
f"onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) are required when using graph "
f"based scripts converter for Pytorch conversion." f"based scripts converter for Pytorch conversion."
) )
log.error(error) log.error(error)
@@ -128,11 +147,11 @@ def tf_installation_validation(func):
output_folder: str, report_folder: str = None, output_folder: str, report_folder: str = None,
input_nodes: str = None, output_nodes: str = None): input_nodes: str = None, output_nodes: str = None):
# Check whether tensorflow is installed. # Check whether tensorflow is installed.
if not _check_tf_installation() or not find_spec("tf2onnx") \
or not find_spec("onnx") or not find_spec("onnxruntime"):
if not _check_tf_installation() or not onnx_satisfied():
error = RuntimeIntegrityError( error = RuntimeIntegrityError(
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) "
f"are required when using graph "
f"based scripts converter for TensorFlow conversion." f"based scripts converter for TensorFlow conversion."
) )
log.error(error) log.error(error)
@@ -141,15 +160,14 @@ def tf_installation_validation(func):
log_console.error("\n") log_console.error("\n")
sys.exit(0) sys.exit(0)


onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
ort = import_module("onnxruntime")
tf2onnx = import_module("tf2onnx")


if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER):
if not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER) \
or not onnx_lib_version_satisfied():
error = RuntimeIntegrityError( error = RuntimeIntegrityError(
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}), "
f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) and onnxoptimizer(>={ONNXOPTIMIZER_MIN_VER}) "
f"are required when using graph "
f"based scripts converter for TensorFlow conversion." f"based scripts converter for TensorFlow conversion."
) )
log.error(error) log.error(error)
@@ -258,12 +276,12 @@ def main_graph_base_converter(file_config):
raise ParamMissingError("Param missing, `--shape` is required when using graph mode.") raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")


if frame_type == FrameworkType.PYTORCH.value: if frame_type == FrameworkType.PYTORCH.value:
check_params = ['input_nodes', 'output_nodes']
check_params_exist(check_params, file_config)
graph_based_converter_pytorch_to_ms(graph_path=graph_path, graph_based_converter_pytorch_to_ms(graph_path=graph_path,
sample_shape=file_config['shape'], sample_shape=file_config['shape'],
input_nodes=file_config['input_nodes'],
output_nodes=file_config['output_nodes'],
input_nodes=file_config['input_nodes'] if file_config['input_nodes']
else 'input.1',
output_nodes=file_config['output_nodes'] if file_config['output_nodes']
else '',
output_folder=file_config['outfile_dir'], output_folder=file_config['outfile_dir'],
report_folder=file_config['report_dir']) report_folder=file_config['report_dir'])
elif frame_type == FrameworkType.TENSORFLOW.value: elif frame_type == FrameworkType.TENSORFLOW.value:


+ 4
- 4
mindinsight/mindconverter/graph_based_converter/generator/__init__.py View File

@@ -18,10 +18,10 @@ __all__ = ["batch_add_nodes"]
import re import re
import copy import copy


from .generator import Generator, CodeStruct
from ..common.code_fragment import CodeFragment, NewFragment
from ..common.outputs import NodeOutputManager
from ..constant import ExchangeMessageKeywords
from mindinsight.mindconverter.graph_based_converter.generator.generator import Generator, CodeStruct
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.common.outputs import NodeOutputManager
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords




def _tf_model_node_name_reformat(node, node_name): def _tf_model_node_name_reformat(node, node_name):


+ 74
- 2
mindinsight/mindconverter/graph_based_converter/generator/generator.py View File

@@ -16,6 +16,7 @@
import copy import copy
from collections import OrderedDict from collections import OrderedDict


from mindspore import Tensor
from yapf.yapflib.yapf_api import FormatCode from yapf.yapflib.yapf_api import FormatCode


from mindinsight.mindconverter.common.exceptions import GeneratorError from mindinsight.mindconverter.common.exceptions import GeneratorError
@@ -28,7 +29,7 @@ from mindinsight.mindconverter.graph_based_converter.common.outputs import BaseO
from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config from mindinsight.mindconverter.graph_based_converter.common.yapf_config import mindspore_yapf_config
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr from mindinsight.mindconverter.graph_based_converter.common.name_mgr import GlobalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \ from mindinsight.mindconverter.graph_based_converter.constant import NEW_LINE, SECOND_LEVEL_INDENT, \
FIRST_LEVEL_INDENT, get_imported_module
FIRST_LEVEL_INDENT, get_imported_module, SEPARATOR_BTW_NAME_AND_ID
from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator from mindinsight.mindconverter.graph_based_converter.report_generator import ReportGenerator
from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list from mindinsight.mindconverter.graph_based_converter.common.utils import replace_string_in_list


@@ -469,6 +470,74 @@ class Generator:
"""Return all ModuleStructs in this model.""" """Return all ModuleStructs in this model."""
return self._module_struct_collections return self._module_struct_collections


def generate_weight_scope_name(self, node):
"""Generate weight scope name for checkpoint."""
replaced_module_dict = self.node_structs[node].global_context_mgr.known_module_name
scope_list = self.node_structs[node].scope.scope_list
ms_var_name = self.node_structs[node].ms_var_name

weight_scope_name = None
for scope in scope_list[1:]:
replaced_module = replaced_module_dict.get(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0])
if replaced_module:
scope = scope.replace(scope.split(SEPARATOR_BTW_NAME_AND_ID)[0], replaced_module)
if not weight_scope_name:
weight_scope_name = scope
else:
weight_scope_name = '.'.join((weight_scope_name, scope))

if not weight_scope_name:
weight_scope_name = ms_var_name
else:
weight_scope_name = '.'.join((weight_scope_name, ms_var_name))

return weight_scope_name.lower()

def generate_checkpoint(self):
"""Generate checkpoint."""

trainable_weights_dict = dict()
weight_map = list()
for node_name, node_inst in self.node_structs.items():
if node_inst.fragment.exchange_msg['var_0']['trainable_params']:
weights_scope_name = self.generate_weight_scope_name(node_name)
onnx_weight_inst = node_inst.fragment.exchange_msg['var_0']['weights']
for idx, (weight_key, weight_value) in \
enumerate(node_inst.fragment.exchange_msg['var_0']['trainable_params'].items()):
weight_name = '.'.join((weights_scope_name, weight_key))
weight_shape = Tensor(weight_value).shape
data_type = Tensor(weight_value).dtype
trainable_weights_dict[weight_name] = weight_value

onnx_weight_name = onnx_weight_inst[idx].name
onnx_weight_shape = onnx_weight_inst[idx].value.shape
onnx_data_type = onnx_weight_inst[idx].value.dtype

weight_map.append(
{
'converted_weight': {
'name': weight_name,
'shape': weight_shape,
'data_type': str(data_type)
},
'source_weight': {
'name': onnx_weight_name,
'shape': onnx_weight_shape,
'data_type': str(onnx_data_type)
}
}
)

save_obj = list()
for weight_name, weight_value in trainable_weights_dict.items():
obj = {
'name': weight_name,
'data': Tensor(weight_value)
}
save_obj.append(obj)

return save_obj, weight_map

@GeneratorError.check_except("Generator occurs an error when generating code statements.") @GeneratorError.check_except("Generator occurs an error when generating code statements.")
def generate(self): def generate(self):
""" """
@@ -479,6 +548,9 @@ class Generator:
""" """
self._form_bottom_submodule() self._form_bottom_submodule()
self._recursive_form_module() self._recursive_form_module()

ckpt_data_list, weight_map = self.generate_checkpoint()

CodeStruct(self.module_structs.get('[]'), self._repeated_submodules) CodeStruct(self.module_structs.get('[]'), self._repeated_submodules)


outputs = [get_imported_module()] outputs = [get_imported_module()]
@@ -494,7 +566,7 @@ class Generator:
report = report_generator.gen_report(formatted_code) report = report_generator.gen_report(formatted_code)
del self._global_context del self._global_context


return {"model": (formatted_code, report)}
return {"model": (formatted_code, report, ckpt_data_list, weight_map)}


def get_node_struct(self, node_identifier): def get_node_struct(self, node_identifier):
""" """


+ 7
- 7
mindinsight/mindconverter/graph_based_converter/generator/module_struct.py View File

@@ -17,13 +17,13 @@
import copy import copy
from collections import OrderedDict from collections import OrderedDict


from .node_struct import NodeStruct
from .scope_utils import Scope
from ..common.utils import get_dict_key_by_value
from .args_translator import ArgsTranslation
from ..common.code_fragment import ModuleFragment
from ..common.global_context import GlobalContext
from ..common.name_mgr import LocalVarNameMgr
from mindinsight.mindconverter.graph_based_converter.generator.node_struct import NodeStruct
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.common.utils import get_dict_key_by_value
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
from mindinsight.mindconverter.graph_based_converter.common.code_fragment import ModuleFragment
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.name_mgr import LocalVarNameMgr




class ModuleStruct: class ModuleStruct:


+ 5
- 5
mindinsight/mindconverter/graph_based_converter/generator/node_struct.py View File

@@ -17,11 +17,11 @@ from collections import OrderedDict


from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment from mindinsight.mindconverter.graph_based_converter.common.code_fragment import NewFragment
from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler from mindinsight.mindconverter.graph_based_converter.generator.fragment_utils import FragmentHandler
from .scope_utils import Scope
from .args_translator import ArgsTranslation
from ..third_party_graph.onnx_graph_node import OnnxGraphNode
from ..common.global_context import GlobalContext
from ...common.exceptions import GeneratorError
from mindinsight.mindconverter.graph_based_converter.generator.scope_utils import Scope
from mindinsight.mindconverter.graph_based_converter.generator.args_translator import ArgsTranslation
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.common.exceptions import GeneratorError




class NodeStruct: class NodeStruct:


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/mapper/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -16,4 +16,4 @@


__all__ = ["ONNXToMindSporeMapper"] __all__ = ["ONNXToMindSporeMapper"]


from .base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper

+ 14
- 7
mindinsight/mindconverter/graph_based_converter/mapper/base.py View File

@@ -108,18 +108,21 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
try: try:
converter_name = op_name_converter(params=params, weights=weights, op_name=op_name) converter_name = op_name_converter(params=params, weights=weights, op_name=op_name)
converted_params = params_converter(params=params, weights=weights) converted_params = params_converter(params=params, weights=weights)

if "input_shape" in converted_params: if "input_shape" in converted_params:
converted_params.pop("input_shape") converted_params.pop("input_shape")
if "output_shape" in converted_params: if "output_shape" in converted_params:
converted_params.pop("output_shape") converted_params.pop("output_shape")
# set to converted_weights to enable weight migration # set to converted_weights to enable weight migration
_ = weights_converter(weights=weights) if weights else dict()
converted_weights = weights_converter(weights=weights) if weights else dict()
code_template, exchange_msg, outputs_list, outputs_mapping = template_generator( code_template, exchange_msg, outputs_list, outputs_mapping = template_generator(
operation=converter_name, operation=converter_name,
converted_params=converted_params, converted_params=converted_params,
raw_params=params, raw_params=params,
weights=weights
weights=weights,
trainable_params=converted_weights
) )

except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e: except (AttributeError, KeyError, ValueError, TypeError, IndexError) as e:
err_msg = f"Converting {op_name} failed, see {str(e)}" err_msg = f"Converting {op_name} failed, see {str(e)}"
log.error(err_msg) log.error(err_msg)
@@ -148,6 +151,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
op = kwargs.get("operation") op = kwargs.get("operation")
args = kwargs.get("converted_params", dict()) args = kwargs.get("converted_params", dict())
weights = kwargs.get("weights") weights = kwargs.get("weights")
trainable_params = kwargs.get("trainable_params", dict())
if not op: if not op:
raise ValueError("Can not get MindSpore operation name.") raise ValueError("Can not get MindSpore operation name.")
variable_slot = "var_0" variable_slot = "var_0"
@@ -169,7 +173,7 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [], ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args, ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights, ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: weights,
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: {}
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: trainable_params
} }
} }
outputs_list = [f"opt_{{{variable_slot}}}"] outputs_list = [f"opt_{{{variable_slot}}}"]
@@ -177,11 +181,14 @@ class ONNXToMindSporeMapper(Mapper, abc.ABC):
return template, exchange_msg, outputs_list, outputs_mapping return template, exchange_msg, outputs_list, outputs_mapping


@staticmethod @staticmethod
def _find_val_by_index(loc_index, values_dict):
"""Find value by location index of values_dict."""
def _find_val_by_index(loc_index, weights_list):
"""Find value by location index of weights_list."""
result = None result = None
for idx, dict_val in enumerate(values_dict.values()):
if loc_index < 0:
return weights_list[loc_index].value

for idx, weight in enumerate(weights_list):
if idx == loc_index: if idx == loc_index:
result = dict_val
result = weight.value
break break
return result return result

+ 11
- 6
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/batch_norm_mapper.py View File

@@ -14,7 +14,6 @@
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class BatchNormMapper(ONNXToMindSporeMapper): class BatchNormMapper(ONNXToMindSporeMapper):
@@ -36,8 +35,14 @@ class BatchNormMapper(ONNXToMindSporeMapper):


@staticmethod @staticmethod
def _convert_trained_weights(**kwargs): def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()
weights = kwargs['weights']
gamma = BatchNormMapper._find_val_by_index(0, weights)
beta = BatchNormMapper._find_val_by_index(1, weights)
moving_mean = BatchNormMapper._find_val_by_index(2, weights)
moving_variance = BatchNormMapper._find_val_by_index(3, weights)
return {
'gamma': gamma,
'beta': beta,
'moving_mean': moving_mean,
'moving_variance': moving_variance
}

+ 19
- 20
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/conv_mapper.py View File

@@ -14,7 +14,6 @@
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
import numpy as np import numpy as np
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string


@@ -42,7 +41,7 @@ class ConvMapper(ONNXToMindSporeMapper):
"""Convert params from PyTorch to MindSpore""" """Convert params from PyTorch to MindSpore"""
weights = kwargs['weights'] weights = kwargs['weights']
params = kwargs['params'] params = kwargs['params']
weight = weights['weight']
weight = ConvMapper._find_val_by_index(0, weights)
weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0]) weight = np.transpose(weight, list(range(2, weight.ndim)) + [1, 0])
if isinstance(params['dilations'], list): if isinstance(params['dilations'], list):
dilation = tuple(params['dilations']) dilation = tuple(params['dilations'])
@@ -76,11 +75,13 @@ class ConvMapper(ONNXToMindSporeMapper):
"""Convert params from Tensorflow to MindSpore""" """Convert params from Tensorflow to MindSpore"""
weights = kwargs['weights'] weights = kwargs['weights']
params = kwargs['params'] params = kwargs['params']
# regex to find Conv weight
weight = list(weights.values())[0]
weight = ConvMapper._find_val_by_index(0, weights)
bias = ConvMapper._find_val_by_index(1, weights)
if weight is None: if weight is None:
raise ValueError("Conv. Mapper cannot get the weight.") raise ValueError("Conv. Mapper cannot get the weight.")


has_bias = isinstance(bias, np.ndarray)

auto_pad = None auto_pad = None
if params.get("auto_pad") is not None: if params.get("auto_pad") is not None:
auto_pad = convert_bytes_string_to_string(params.get("auto_pad")) auto_pad = convert_bytes_string_to_string(params.get("auto_pad"))
@@ -119,18 +120,14 @@ class ConvMapper(ONNXToMindSporeMapper):
'padding': padding, 'padding': padding,
'pad_mode': pad_mode, 'pad_mode': pad_mode,
'dilation': dilation, 'dilation': dilation,
'group': params.get('group', 1)}
'group': params.get('group', 1),
'has_bias': has_bias
}


@staticmethod @staticmethod
def _operation_name_in_ms(*args, **kwargs): def _operation_name_in_ms(*args, **kwargs):
weight = kwargs['weights'].get('weight', 'empty')

if weight == 'empty': # is from tf
kernel_size = kwargs['params'].get('kernel_shape')
dim = len(kernel_size)
return f"nn.Conv{dim}d"

dim = weight.ndim - 2
kernel_size = kwargs['params'].get('kernel_shape')
dim = len(kernel_size)
return f"nn.Conv{dim}d" return f"nn.Conv{dim}d"


@staticmethod @staticmethod
@@ -138,14 +135,16 @@ class ConvMapper(ONNXToMindSporeMapper):
weights = kwargs['weights'] weights = kwargs['weights']
params = kwargs['params'] params = kwargs['params']


if weights.get('weight', 'empty') == 'empty': # is from tf
return ConvMapper.convert_params_tf(params=params, weights=weights)
return ConvMapper.convert_params_torch(params=params, weights=weights)
return ConvMapper.convert_params_tf(params=params, weights=weights)


@staticmethod @staticmethod
def _convert_trained_weights(**kwargs): def _convert_trained_weights(**kwargs):
return dict()
weights = kwargs['weights']
weight = ConvMapper._find_val_by_index(0, weights)
bias = ConvMapper._find_val_by_index(1, weights)


@staticmethod
def _convert_settings(**kwargs):
return Setting()
converted_weights = {'weight': weight}
if isinstance(bias, np.ndarray):
converted_weights['bias'] = bias

return converted_weights

+ 7
- 6
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/dense_mapper.py View File

@@ -15,7 +15,6 @@
"""Mapper module.""" """Mapper module."""
import numpy as np import numpy as np
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




class DenseMapper(ONNXToMindSporeMapper): class DenseMapper(ONNXToMindSporeMapper):
@@ -42,8 +41,10 @@ class DenseMapper(ONNXToMindSporeMapper):


@staticmethod @staticmethod
def _convert_trained_weights(**kwargs): def _convert_trained_weights(**kwargs):
return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()
weights = kwargs['weights']
weight = DenseMapper._find_val_by_index(0, weights)
bias = DenseMapper._find_val_by_index(1, weights)
return {
'weight': weight,
'bias': bias
}

+ 4
- 3
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/mat_mul_mapper.py View File

@@ -30,7 +30,9 @@ class MatMulMapper(ONNXToMindSporeMapper):


@staticmethod @staticmethod
def _convert_trained_weights(**kwargs): def _convert_trained_weights(**kwargs):
return dict()
weights = kwargs['weights']
weight = MatMulMapper._find_val_by_index(0, weights)
return {'weight': weight}


@staticmethod @staticmethod
def _generate_snippet_template(**kwargs): def _generate_snippet_template(**kwargs):
@@ -44,8 +46,7 @@ class MatMulMapper(ONNXToMindSporeMapper):
if not weights: if not weights:
return template, exchange_msg, outputs_list, outputs_mapping return template, exchange_msg, outputs_list, outputs_mapping


weight = list(weights.items())[0]
_, tensor = weight
tensor = MatMulMapper._find_val_by_index(0, weights)


variable_slot = "var_0" variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"


+ 1
- 6
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pad_mapper.py View File

@@ -15,7 +15,6 @@
"""Mapper module.""" """Mapper module."""
from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string from mindinsight.mindconverter.graph_based_converter.common.utils import convert_bytes_string_to_string
from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting




def _padding_format_convert(padding: list): def _padding_format_convert(padding: list):
@@ -49,7 +48,7 @@ class PadMapper(ONNXToMindSporeMapper):
weights = kwargs.get("weights") weights = kwargs.get("weights")
params = kwargs.get("params") params = kwargs.get("params")
mode = convert_bytes_string_to_string(params.get('mode', 'constant')) mode = convert_bytes_string_to_string(params.get('mode', 'constant'))
pads_onnx = params.get("pads") if params.get("pads") else list(weights.values())[0].tolist()
pads_onnx = params.get("pads") if params.get("pads") else PadMapper._find_val_by_index(0, weights).tolist()
if mode == 'constant' and params.get('value') is None: if mode == 'constant' and params.get('value') is None:
if params.get('pads') or weights: if params.get('pads') or weights:
if isinstance(pads_onnx, list): if isinstance(pads_onnx, list):
@@ -76,7 +75,3 @@ class PadMapper(ONNXToMindSporeMapper):
@staticmethod @staticmethod
def _convert_trained_weights(**kwargs): def _convert_trained_weights(**kwargs):
return dict() return dict()

@staticmethod
def _convert_settings(**kwargs):
return Setting()

+ 103
- 10
mindinsight/mindconverter/graph_based_converter/mapper/impl/nn/pool_mapper.py View File

@@ -13,12 +13,16 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Mapper module.""" """Mapper module."""
import math

import numpy as np

from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper from mindinsight.mindconverter.graph_based_converter.mapper.base import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.mapper.gen_setting import Setting
from mindinsight.mindconverter.graph_based_converter.constant import ExchangeMessageKeywords, TemplateKeywords




class PoolMapper(ONNXToMindSporeMapper): class PoolMapper(ONNXToMindSporeMapper):
"""MaxPool mapper."""
"""Pool mapper."""


@staticmethod @staticmethod
def _operation_name_in_ms(*args, **kwargs): def _operation_name_in_ms(*args, **kwargs):
@@ -35,12 +39,6 @@ class PoolMapper(ONNXToMindSporeMapper):
transformed_params = dict() transformed_params = dict()
transformed_params["kernel_size"] = tuple(params['kernel_shape']) transformed_params["kernel_size"] = tuple(params['kernel_shape'])
transformed_params["stride"] = tuple(params['strides']) transformed_params["stride"] = tuple(params['strides'])
if "pads" in params:
if sum(params['pads']) == 0 and not params.get('ceil_mode', None):
pad_mode = '\"valid\"'
else:
pad_mode = '\"same\"'
transformed_params["pad_mode"] = pad_mode


return transformed_params return transformed_params


@@ -49,5 +47,100 @@ class PoolMapper(ONNXToMindSporeMapper):
return dict() return dict()


@staticmethod @staticmethod
def _convert_settings(**kwargs):
return Setting()
def _get_ms_opt_shape(**kwargs):
"""Get output shape in MindSpore."""
params = kwargs['raw_params']
input_shape = params['input_shape']
kernel_shape = params['kernel_shape']
strides = params['strides']
dilations = params.get('dilations', (1, 1))
# For mindspore,
# output_shape[i] = ceil((input_shape[i] - ((kernel_shape[i] - 1) * dilations[i] + 1) + 1) / strides[i])
ms_opt_shape = np.true_divide(np.subtract(np.array(input_shape[-len(kernel_shape):], dtype=np.float32),
((np.array(kernel_shape, dtype=np.float32) - 1) *
np.array(dilations, dtype=np.float32) + 1)) + 1,
np.array(strides, dtype=np.float32)).tolist()
ms_opt_shape_ceil = tuple(math.ceil(ms_opt_shape_axis) for ms_opt_shape_axis in ms_opt_shape)
return ms_opt_shape_ceil

@staticmethod
def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs)
op = kwargs.get("operation")
args = kwargs.get("converted_params", dict())

ms_opt_shape = PoolMapper._get_ms_opt_shape(**kwargs)
tensor_opt_shape = kwargs['raw_params']['output_shape']
tensor_ipt_shape = kwargs['raw_params']['input_shape']
kernel_shape = kwargs['raw_params']['kernel_shape']
dilations = kwargs['raw_params'].get('dilations', (1, 1))
strides = kwargs['raw_params']['strides']
onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):]

if np.all(np.array(ms_opt_shape) == np.array(onnx_opt_shape)):
return template, exchange_msg, outputs_list, outputs_mapping

variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"
construct_template = f"opt_{{{variable_slot}}} = self.{{{variable_slot}}}(opt_{{{variable_slot}}})"

init_template_pad, construct_template_pad, paddings = \
PoolMapper._generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape,
ms_opt_shape, variable_slot,
kernel_shape, dilations, strides)

template = {
variable_slot: {
TemplateKeywords.INIT.value: [init_template_pad, init_template],
TemplateKeywords.CONSTRUCT.value: [construct_template_pad, construct_template]
}
}

args['paddings'] = paddings

exchange_msg = {
variable_slot: {
ExchangeMessageKeywords.VariableScope.value.OPERATION.value: op,
ExchangeMessageKeywords.VariableScope.value.VARIABLE_NAME.value: None,
ExchangeMessageKeywords.VariableScope.value.OUTPUT_TYPE.value:
ExchangeMessageKeywords.VariableScope.value.TSR_TYPE.value,
ExchangeMessageKeywords.VariableScope.value.INPUTS.value: [],
ExchangeMessageKeywords.VariableScope.value.ARGS.value: args,
ExchangeMessageKeywords.VariableScope.value.WEIGHTS.value: dict(),
ExchangeMessageKeywords.VariableScope.value.TRAINABLE_PARAMS.value: dict()
}
}

return template, exchange_msg, outputs_list, outputs_mapping

@staticmethod
def _generate_pad_init_and_construct(tensor_opt_shape, tensor_ipt_shape,
ms_opt_shape, variable_slot, kernel_shape, dilations, strides):
"""Generate pad code in init and construct."""
onnx_opt_shape = tensor_opt_shape[-len(ms_opt_shape):]
onnx_ipt_shape = tensor_ipt_shape[-len(ms_opt_shape):]

if np.any(np.array(ms_opt_shape) > np.array(onnx_opt_shape)):
raise ValueError(f"ms_opt_shape[{ms_opt_shape}] should be no larger than onnx_opt_shape[{onnx_opt_shape}].")

# shape_diff[i] = (onnx_opt_shape[i] - 1)*strides[i] -
# (onnx_ipt_shape[i] - ((kernel_shape[i] - 1)*dilations[i] + 1))
shape_diff = np.subtract((np.array(onnx_opt_shape) - 1)*np.array(strides),
np.subtract(np.array(onnx_ipt_shape),
(np.array(kernel_shape) - 1)*np.array(dilations) + 1)).tolist()

zero_pad_single = (0, 0)
paddings = [zero_pad_single]
num_zero_pads = len(tensor_opt_shape) - len(ms_opt_shape)
for _ in range(num_zero_pads - 1):
paddings.append(zero_pad_single)

for axis_diff in shape_diff:
paddings.append((int(axis_diff//2), int(axis_diff//2 + axis_diff % 2)))

init_template_pad = f"self.pad_{{{variable_slot}}} = nn.Pad(paddings={{paddings}})"
construct_template_pad = f"opt_{{{variable_slot}}} = self.pad_{{{variable_slot}}}" \
f"({{{ExchangeMessageKeywords.VariableScope.value.INPUTS.value}}})"

return init_template_pad, construct_template_pad, tuple(paddings)

+ 1
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/add_mapper.py View File

@@ -44,8 +44,7 @@ class AddMapper(ONNXToMindSporeMapper):
if not weights: if not weights:
return template, exchange_msg, outputs_list, outputs_mapping return template, exchange_msg, outputs_list, outputs_mapping


bias = list(weights.items())[0]
_, tensor = bias
tensor = AddMapper._find_val_by_index(0, weights)


variable_slot = "var_0" variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"


+ 1
- 2
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/mul_mapper.py View File

@@ -53,8 +53,7 @@ class MulMapper(ONNXToMindSporeMapper):
if not weights: if not weights:
return template, exchange_msg, outputs_list, outputs_mapping return template, exchange_msg, outputs_list, outputs_mapping


weight = list(weights.items())[0]
_, tensor = weight
tensor = MulMapper._find_val_by_index(0, weights)


variable_slot = "var_0" variable_slot = "var_0"
init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})" init_template = f"self.{{{variable_slot}}} = {op}({', '.join(['%s={%s}' % (p, p) for p in args])})"


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/resize_mapper.py View File

@@ -60,7 +60,7 @@ class ResizeMapper(ONNXToMindSporeMapper):
align_corners = True align_corners = True


# Get requested size for resize # Get requested size for resize
size = list(weights.values())[-1][-2:].tolist()
size = ResizeMapper._find_val_by_index(-1, weights)[-2:].tolist()


return {"size": tuple(size), return {"size": tuple(size),
"align_corners": align_corners} "align_corners": align_corners}


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/mapper/impl/ops/slice_mapper.py View File

@@ -48,7 +48,7 @@ class SliceMapper(ONNXToMindSporeMapper):
def _generate_snippet_template(**kwargs): def _generate_snippet_template(**kwargs):
template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template( template, exchange_msg, outputs_list, outputs_mapping = ONNXToMindSporeMapper._generate_snippet_template(
**kwargs) **kwargs)
weights = list(kwargs.get("weights").values()) # start, end, axis
weights = [weight.value for weight in kwargs.get('weights')] # start, end, axis
opt_shape = kwargs["raw_params"]["output_shape"] opt_shape = kwargs["raw_params"]["output_shape"]
if not weights: if not weights:
raise ValueError("Cannot get required params from slice.") raise ValueError("Cannot get required params from slice.")


+ 2
- 2
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/__init__.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -15,4 +15,4 @@
"""Searcher of scope name.""" """Searcher of scope name."""
__all__ = ["generate_scope_name"] __all__ = ["generate_scope_name"]


from .searcher import generate_scope_name
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.searcher import generate_scope_name

+ 3
- 3
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/built_in_pattern.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -16,8 +16,8 @@


__all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"] __all__ = ["BUILT_IN_PATTERN", "register_pattern", "is_built_in_pattern"]


from .common import cal_matching_score
from .pattern import Pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern


BUILT_IN_PATTERN = dict() BUILT_IN_PATTERN = dict()




+ 10
- 7
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/search_path.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -17,12 +17,15 @@ import copy
import uuid import uuid
from typing import Dict, List, Callable, Union from typing import Dict, List, Callable, Union
from collections import OrderedDict from collections import OrderedDict
from .common import context, gen_hash_key, DagGraph, MAX_OUT_DEGREE, cal_matching_score
from .known_module_name import BUILT_IN_MODULE_NAME
from .pattern import Pattern, scope_name_mapping
from .built_in_pattern import BUILT_IN_PATTERN, is_built_in_pattern
from .pattern_fuzzy_matching import pattern_fuzzy_matching
from ..third_party_graph.onnx_utils import OnnxNode, BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, gen_hash_key, DagGraph, \
MAX_OUT_DEGREE, cal_matching_score
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.known_module_name import BUILT_IN_MODULE_NAME
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern import Pattern, scope_name_mapping
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.built_in_pattern import BUILT_IN_PATTERN, \
is_built_in_pattern
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.pattern_fuzzy_matching import \
pattern_fuzzy_matching
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxNode, BaseNode


module_name_to_src = {} module_name_to_src = {}
used_module_name = dict() used_module_name = dict()


+ 9
- 6
mindinsight/mindconverter/graph_based_converter/sub_graph_searcher/searcher.py View File

@@ -16,12 +16,15 @@
from queue import PriorityQueue from queue import PriorityQueue
from typing import Dict, List from typing import Dict, List


from .common import context, DagGraph, gen_hash_key, ACCEPTABLE_RESULT_COUNT
from .common import MINI_FREQUENCY, MAX_ITERATION_DEPTH, SATISFIED_SCORE
from ..common.global_context import GlobalContext
from ..third_party_graph.onnx_utils import BaseNode
from .search_path import SearchPath, Pattern, generate_pattern, find_built_in_pattern
from ...common.exceptions import SubGraphSearchingError
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import context, DagGraph, gen_hash_key, \
ACCEPTABLE_RESULT_COUNT
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.common import MINI_FREQUENCY, \
MAX_ITERATION_DEPTH, SATISFIED_SCORE
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import BaseNode
from mindinsight.mindconverter.graph_based_converter.sub_graph_searcher.search_path import SearchPath, Pattern, \
generate_pattern, find_built_in_pattern
from mindinsight.mindconverter.common.exceptions import SubGraphSearchingError




def _is_satisfied(path): def _is_satisfied(path):


+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/__init__.py View File

@@ -17,7 +17,7 @@
__all__ = ["GraphFactory"] __all__ = ["GraphFactory"]
from importlib import import_module from importlib import import_module


from .base import Graph
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph




class GraphFactory: class GraphFactory:


+ 3
- 4
mindinsight/mindconverter/graph_based_converter/third_party_graph/input_node.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -15,8 +15,8 @@
"""Define PyTorch graph node.""" """Define PyTorch graph node."""
import os import os


from .base import GraphNode
from ..constant import SEPARATOR_IN_SCOPE, NodeType
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode
from mindinsight.mindconverter.graph_based_converter.constant import SEPARATOR_IN_SCOPE, NodeType




class InputNode(GraphNode): class InputNode(GraphNode):
@@ -25,7 +25,6 @@ class InputNode(GraphNode):


Args: Args:
input_shape: Input shape of module. input_shape: Input shape of module.

""" """


def _get_arg_name(self, arg, variable_name): def _get_arg_name(self, arg, variable_name):


+ 9
- 9
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph.py View File

@@ -17,12 +17,12 @@ from importlib import import_module
from typing import Dict, NoReturn from typing import Dict, NoReturn


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from .base import Graph
from .input_node import InputNode
from .onnx_graph_node import OnnxGraphNode
from .pytorch_graph_parser import PyTorchGraphParser
from .tf_graph_parser import TFGraphParser
from .onnx_utils import OnnxDataLoader
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import Graph
from mindinsight.mindconverter.graph_based_converter.third_party_graph.input_node import InputNode
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_graph_node import OnnxGraphNode
from mindinsight.mindconverter.graph_based_converter.third_party_graph.pytorch_graph_parser import PyTorchGraphParser
from mindinsight.mindconverter.graph_based_converter.third_party_graph.tf_graph_parser import TFGraphParser
from mindinsight.mindconverter.graph_based_converter.third_party_graph.onnx_utils import OnnxDataLoader, NodeWeight


NONE_SCOPE_OP = { NONE_SCOPE_OP = {
"onnx::Add": "Add", "onnx::Add": "Add",
@@ -126,7 +126,7 @@ class OnnxGraph(Graph):


self._shape_dict = model_data.node_output_shape_dict self._shape_dict = model_data.node_output_shape_dict
for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()): for ind, (node_name, node) in enumerate(model_data.nodes_dict.items()):
node_weight = {}
node_weights = list()
node.scope_name = scope_name_list[ind] node.scope_name = scope_name_list[ind]
inputs = node.input_name_list inputs = node.input_name_list
# check each input from node or tensors # check each input from node or tensors
@@ -135,8 +135,8 @@ class OnnxGraph(Graph):
tensor = model_data.tensors_dict[i] tensor = model_data.tensors_dict[i]
t_name = tensor.name t_name = tensor.name
t_value = tensor.to_array() t_value = tensor.to_array()
node_weight[t_name] = t_value
self._nodes_collection[node_name] = OnnxGraphNode(node, node_weight)
node_weights.append(NodeWeight(t_name, t_value))
self._nodes_collection[node_name] = OnnxGraphNode(node, node_weights)
self._nodes_record[node_name] = node_name self._nodes_record[node_name] = node_name


for nd_ipt_name in node.precursor_onnx_node_dict: for nd_ipt_name in node.precursor_onnx_node_dict:


+ 6
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_graph_node.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -15,11 +15,11 @@
"""Define ONNX graph node.""" """Define ONNX graph node."""
from importlib import import_module from importlib import import_module


from .base import GraphNode
from ..common.utils import is_converted
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphNode
from mindinsight.mindconverter.graph_based_converter.common.utils import is_converted


from ..constant import NodeType, SEPARATOR_IN_SCOPE, SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, \
SEPARATOR_IN_ONNX_OP
from mindinsight.mindconverter.graph_based_converter.constant import NodeType, SEPARATOR_IN_SCOPE, \
SEPARATOR_BTW_NAME_AND_ID, LEFT_BUCKET, RIGHT_BUCKET, SEPARATOR_IN_ONNX_OP




class OnnxGraphNode(GraphNode): class OnnxGraphNode(GraphNode):
@@ -28,7 +28,7 @@ class OnnxGraphNode(GraphNode):


Args: Args:
node (OnnxNode): OnnxNode Object. node (OnnxNode): OnnxNode Object.
weight (dict): Dictionary records weight and bias.
weight (list): List of recording node weights.
""" """
_type_frozen = False _type_frozen = False
_module_name_frozen = False _module_name_frozen = False
@@ -227,7 +227,6 @@ class OnnxGraphNode(GraphNode):
Args: Args:
src_arg (str): Original arg name. src_arg (str): Original arg name.
tgt_arg (str): Target arg name. tgt_arg (str): Target arg name.

""" """
self._args_in_code[src_arg] = tgt_arg self._args_in_code[src_arg] = tgt_arg




+ 20
- 7
mindinsight/mindconverter/graph_based_converter/third_party_graph/onnx_utils.py View File

@@ -22,13 +22,13 @@ from typing import Union
import numpy as np import numpy as np


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from ..common.utils import fetch_output_from_onnx_model
from ..common.global_context import GlobalContext
from .optimizer import OnnxSimplify
from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.third_party_graph.optimizer import OnnxSimplify


from ..constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
from mindinsight.mindconverter.graph_based_converter.constant import ONNX_TYPE_INT, ONNX_TYPE_INTS, ONNX_TYPE_STRING, \
ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL ONNX_TYPE_FLOATS, ONNX_TYPE_FLOAT, SCALAR_WITHOUT_SHAPE, DYNAMIC_SHAPE, UNKNOWN_DIM_VAL
from ...common.exceptions import GraphInitError, ModelLoadingError
from mindinsight.mindconverter.common.exceptions import GraphInitError, ModelLoadingError




def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12): def convert_tf_graph_to_onnx(model_path, model_inputs, model_outputs, opset=12):
@@ -128,7 +128,6 @@ class ParamsAttribute:
raw_attribute (onnx.AttributeProto): onnx.AttributeProto instance. raw_attribute (onnx.AttributeProto): onnx.AttributeProto instance.
node (onnx.NodeProto): Must pass the onnx.NodeProto instance node (onnx.NodeProto): Must pass the onnx.NodeProto instance
containing the same AttributeProto. containing the same AttributeProto.

""" """


def __init__(self, raw_attribute, node): def __init__(self, raw_attribute, node):
@@ -148,7 +147,6 @@ class ParamsAttribute:


Args: Args:
attrs (onnx.AttributeProto): onnx.AttributeProto instance. attrs (onnx.AttributeProto): onnx.AttributeProto instance.

""" """
if not attrs: if not attrs:
return return
@@ -604,3 +602,18 @@ class OnnxDataLoader:
eliminated_nodes = _traceback_precursor_nodes_until_shape_op(to_shape) eliminated_nodes = _traceback_precursor_nodes_until_shape_op(to_shape)
self.dynamic_resize_node.append(nd_name) self.dynamic_resize_node.append(nd_name)
self.eliminated_nodes += eliminated_nodes self.eliminated_nodes += eliminated_nodes


class NodeWeight:
"""Node weight struct."""
def __init__(self, weight_name, weight_value):
self._weight_name = weight_name
self._weight_value = weight_value

@property
def name(self):
return self._weight_name

@property
def value(self):
return self._weight_value

+ 1
- 1
mindinsight/mindconverter/graph_based_converter/third_party_graph/optimizer.py View File

@@ -18,7 +18,7 @@ from importlib import import_module


import numpy as np import numpy as np


from ..common.utils import fetch_output_from_onnx_model
from mindinsight.mindconverter.graph_based_converter.common.utils import fetch_output_from_onnx_model




class OnnxSimplify: class OnnxSimplify:


+ 1
- 0
mindinsight/mindconverter/graph_based_converter/third_party_graph/pytorch_graph_parser.py View File

@@ -20,6 +20,7 @@ from mindinsight.mindconverter.common.log import logger as log
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser
from mindinsight.mindconverter.common.exceptions import ModelNotSupportError from mindinsight.mindconverter.common.exceptions import ModelNotSupportError



class PyTorchGraphParser(GraphParser): class PyTorchGraphParser(GraphParser):
"""Define pytorch graph parser.""" """Define pytorch graph parser."""




+ 3
- 3
mindinsight/mindconverter/graph_based_converter/third_party_graph/tf_graph_parser.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd.All Rights Reserved.
# Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -18,8 +18,8 @@ import re
from importlib import import_module from importlib import import_module


from mindinsight.mindconverter.common.log import logger as log from mindinsight.mindconverter.common.log import logger as log
from .base import GraphParser
from ...common.exceptions import ModelNotSupportError
from mindinsight.mindconverter.graph_based_converter.third_party_graph.base import GraphParser
from mindinsight.mindconverter.common.exceptions import ModelNotSupportError




class TFGraphParser(GraphParser): class TFGraphParser(GraphParser):


+ 6
- 2
tests/ut/mindconverter/graph_based_converter/mapper/test_mapper.py View File

@@ -89,7 +89,9 @@ class TestMappers:
'input': {'op_name': 'onnx::MaxPool', 'input': {'op_name': 'onnx::MaxPool',
'params': {'kernel_shape': [3, 3], 'params': {'kernel_shape': [3, 3],
'pads': [1, 1, 1, 1], 'pads': [1, 1, 1, 1],
'strides': [2, 2]},
'strides': [2, 2],
'input_shape': (1, 3, 224, 224),
'output_shape': (1, 3, 112, 112)},
'weights': dict()}, 'weights': dict()},
'expected_output': {'converter_name': 'nn.MaxPool2d', 'expected_output': {'converter_name': 'nn.MaxPool2d',
'converted_params': {'kernel_size': (3, 3), 'converted_params': {'kernel_size': (3, 3),
@@ -100,7 +102,9 @@ class TestMappers:
'input': {'op_name': 'onnx::AveragePool', 'input': {'op_name': 'onnx::AveragePool',
'params': {'kernel_shape': [3, 3], 'params': {'kernel_shape': [3, 3],
'pads': [1, 1, 1, 1], 'pads': [1, 1, 1, 1],
'strides': [2, 2]},
'strides': [2, 2],
'input_shape': (1, 3, 224, 224),
'output_shape': (1, 3, 112, 112)},
'weights': dict()}, 'weights': dict()},
'expected_output': {'converter_name': 'nn.AvgPool2d', 'expected_output': {'converter_name': 'nn.AvgPool2d',
'converted_params': {'kernel_size': (3, 3), 'converted_params': {'kernel_size': (3, 3),


+ 26
- 0
tests/utils/mindspore/train/serialization.py View File

@@ -0,0 +1,26 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock the MindSpore mindspore/train/serialization.py."""


def save_checkpoint(trainable_weights, ckpt_file_name):
"""
Mock save_checkpoint.

Args:
trainable_weights (list): List of weights.
ckpt_file_name (str): Path to save checkpoint file.
"""
return len(trainable_weights), ckpt_file_name

Loading…
Cancel
Save