You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

framework.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd.All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Graph based scripts converter workflow."""
import argparse
import functools
import os
import sys
from importlib import import_module
from importlib.util import find_spec

import mindinsight
from mindinsight.mindconverter.common.exceptions import GraphInitError, TreeCreationError, SourceFilesSaveError, \
    BaseConverterError, UnknownModelError, GeneratorError, TfRuntimeError, RuntimeIntegrityError, ParamMissingError
from mindinsight.mindconverter.common.log import logger as log, logger_console as log_console
from mindinsight.mindconverter.graph_based_converter.common.global_context import GlobalContext
from mindinsight.mindconverter.graph_based_converter.common.utils import lib_version_satisfied, \
    save_code_file_and_report, get_framework_type
from mindinsight.mindconverter.graph_based_converter.constant import FrameworkType, \
    ONNX_MIN_VER, TF2ONNX_MIN_VER, ONNXRUNTIME_MIN_VER
from mindinsight.mindconverter.graph_based_converter.generator import batch_add_nodes
from mindinsight.mindconverter.graph_based_converter.mapper import ONNXToMindSporeMapper
from mindinsight.mindconverter.graph_based_converter.third_party_graph import GraphFactory
  33. permissions = os.R_OK | os.W_OK | os.X_OK
  34. os.umask(permissions << 3 | permissions)
  35. parser = argparse.ArgumentParser(
  36. prog="MindConverter",
  37. description="Graph based MindConverter CLI entry point (version: {})".format(
  38. mindinsight.__version__)
  39. )
  40. parser.add_argument("--graph", type=str, required=True,
  41. help="Third party framework's graph path.")
  42. parser.add_argument("--sample_shape", nargs='+', type=int, required=True,
  43. help="Input shape of the model.")
  44. parser.add_argument("--ckpt", type=str, required=False,
  45. help="Third party framework's checkpoint path.")
  46. parser.add_argument("--output", type=str, required=True,
  47. help="Generated scripts output folder path.")
  48. parser.add_argument("--report", type=str, required=False,
  49. help="Generated reports output folder path.")
  50. def torch_installation_validation(func):
  51. """
  52. Validate args of func.
  53. Args:
  54. func (type): Function.
  55. Returns:
  56. type, inner function.
  57. """
  58. def _f(graph_path: str, sample_shape: tuple,
  59. input_nodes: str, output_nodes: str,
  60. output_folder: str, report_folder: str = None):
  61. # Check whether pytorch is installed.
  62. if not find_spec("torch") or not find_spec("onnx") or not find_spec("onnxruntime"):
  63. error = RuntimeIntegrityError(f"PyTorch, onnx(>={ONNX_MIN_VER}) and "
  64. f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) "
  65. f"are required when using graph based "
  66. f"scripts converter, and PyTorch version must "
  67. f"be consisted with model generation runtime.")
  68. log.error(error)
  69. log_console.error("\n")
  70. log_console.error(str(error))
  71. log_console.error("\n")
  72. sys.exit(0)
  73. onnx = import_module("onnx")
  74. ort = import_module("onnxruntime")
  75. if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
  76. or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER):
  77. error = RuntimeIntegrityError(
  78. f"onnx(>={ONNX_MIN_VER}) and "
  79. f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
  80. f"based scripts converter for Pytorch conversion."
  81. )
  82. log.error(error)
  83. log_console.error("\n")
  84. log_console.error(str(error))
  85. log_console.error("\n")
  86. sys.exit(0)
  87. func(graph_path=graph_path, sample_shape=sample_shape,
  88. input_nodes=input_nodes, output_nodes=output_nodes,
  89. output_folder=output_folder, report_folder=report_folder)
  90. return _f
  91. def _check_tf_installation():
  92. """
  93. Check whether TensorFlow was installed.
  94. Returns:
  95. bool, true or false.
  96. """
  97. return find_spec("tensorflow") or find_spec("tensorflow-gpu")
  98. def tf_installation_validation(func):
  99. """
  100. Validate args of func.
  101. Args:
  102. func(type): Function.
  103. Returns:
  104. type, inner function.
  105. """
  106. def _f(graph_path: str, sample_shape: tuple,
  107. output_folder: str, report_folder: str = None,
  108. input_nodes: str = None, output_nodes: str = None):
  109. # Check whether tensorflow is installed.
  110. if not _check_tf_installation() or not find_spec("tf2onnx") \
  111. or not find_spec("onnx") or not find_spec("onnxruntime"):
  112. error = RuntimeIntegrityError(
  113. f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
  114. f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
  115. f"based scripts converter for TensorFlow conversion."
  116. )
  117. log.error(error)
  118. log_console.error("\n")
  119. log_console.error(str(error))
  120. log_console.error("\n")
  121. sys.exit(0)
  122. onnx, tf2onnx = import_module("onnx"), import_module("tf2onnx")
  123. ort = import_module("onnxruntime")
  124. if not lib_version_satisfied(getattr(onnx, "__version__"), ONNX_MIN_VER) \
  125. or not lib_version_satisfied(getattr(ort, "__version__"), ONNXRUNTIME_MIN_VER) \
  126. or not lib_version_satisfied(getattr(tf2onnx, "__version__"), TF2ONNX_MIN_VER):
  127. error = RuntimeIntegrityError(
  128. f"TensorFlow, tf2onnx(>={TF2ONNX_MIN_VER}), onnx(>={ONNX_MIN_VER}) and "
  129. f"onnxruntime(>={ONNXRUNTIME_MIN_VER}) are required when using graph "
  130. f"based scripts converter for TensorFlow conversion."
  131. )
  132. log.error(error)
  133. log_console.error("\n")
  134. log_console.error(str(error))
  135. log_console.error("\n")
  136. sys.exit(0)
  137. func(graph_path=graph_path, sample_shape=sample_shape,
  138. output_folder=output_folder, report_folder=report_folder,
  139. input_nodes=input_nodes, output_nodes=output_nodes)
  140. return _f
  141. def _extract_model_name(model_path):
  142. """
  143. Extract model name from model path.
  144. Args:
  145. model_path(str): Path of Converted model.
  146. Returns:
  147. str: Name of Converted model.
  148. """
  149. base_path = os.path.basename(model_path)
  150. model_name = '.'.join(base_path.split('.')[:-1])
  151. return model_name
  152. @torch_installation_validation
  153. @GraphInitError.uniform_catcher()
  154. @TreeCreationError.uniform_catcher()
  155. @SourceFilesSaveError.uniform_catcher()
  156. @GeneratorError.uniform_catcher()
  157. def graph_based_converter_pytorch_to_ms(graph_path: str, sample_shape: tuple,
  158. input_nodes: str, output_nodes: str,
  159. output_folder: str, report_folder: str = None):
  160. """
  161. PyTorch to MindSpore based on Graph.
  162. Args:
  163. graph_path (str): Graph file path.
  164. sample_shape (tuple): Input shape of the model.
  165. input_nodes (str): Input node(s) of the model.
  166. output_nodes (str): Output node(s) of the model.
  167. output_folder (str): Output folder.
  168. report_folder (str): Report output folder path.
  169. """
  170. graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
  171. input_nodes=input_nodes, output_nodes=output_nodes)
  172. generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
  173. model_name = _extract_model_name(graph_path)
  174. code_fragments = generator_inst.generate()
  175. save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
  176. # Release global context.
  177. GlobalContext.release()
  178. @tf_installation_validation
  179. @GraphInitError.uniform_catcher()
  180. @TfRuntimeError.uniform_catcher()
  181. @TreeCreationError.uniform_catcher()
  182. @SourceFilesSaveError.uniform_catcher()
  183. @GeneratorError.uniform_catcher()
  184. def graph_based_converter_tf_to_ms(graph_path: str, sample_shape: tuple,
  185. input_nodes: str, output_nodes: str,
  186. output_folder: str, report_folder: str = None):
  187. """
  188. Tensorflow to MindSpore based on Graph.
  189. Args:
  190. graph_path(str): Graph file path.
  191. sample_shape(tuple): Input shape of the model.
  192. input_nodes(str): Input node(s) of the model.
  193. output_nodes(str): Output node(s) of the model.
  194. output_folder(str): Output folder.
  195. report_folder(str): Report output folder path.
  196. """
  197. # Close unnecessary log.
  198. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
  199. graph_obj = GraphFactory.init(graph_path, sample_shape=sample_shape,
  200. input_nodes=input_nodes, output_nodes=output_nodes)
  201. generator_inst = batch_add_nodes(graph_obj, ONNXToMindSporeMapper)
  202. model_name = _extract_model_name(graph_path)
  203. code_fragments = generator_inst.generate()
  204. save_code_file_and_report(model_name, code_fragments, output_folder, report_folder)
  205. # Release global context.
  206. GlobalContext.release()
  207. @BaseConverterError.uniform_catcher()
  208. def main_graph_base_converter(file_config):
  209. """
  210. The entrance for converter, script files will be converted.
  211. Args:
  212. file_config (dict): The config of file which to convert.
  213. """
  214. graph_path = file_config['model_file']
  215. frame_type = get_framework_type(graph_path)
  216. if not file_config.get("shape"):
  217. raise ParamMissingError("Param missing, `--shape` is required when using graph mode.")
  218. if frame_type == FrameworkType.PYTORCH.value:
  219. check_params = ['input_nodes', 'output_nodes']
  220. check_params_exist(check_params, file_config)
  221. graph_based_converter_pytorch_to_ms(graph_path=graph_path,
  222. sample_shape=file_config['shape'],
  223. input_nodes=file_config['input_nodes'],
  224. output_nodes=file_config['output_nodes'],
  225. output_folder=file_config['outfile_dir'],
  226. report_folder=file_config['report_dir'])
  227. elif frame_type == FrameworkType.TENSORFLOW.value:
  228. check_params = ['input_nodes', 'output_nodes']
  229. check_params_exist(check_params, file_config)
  230. graph_based_converter_tf_to_ms(graph_path=graph_path,
  231. sample_shape=file_config['shape'],
  232. input_nodes=file_config['input_nodes'],
  233. output_nodes=file_config['output_nodes'],
  234. output_folder=file_config['outfile_dir'],
  235. report_folder=file_config['report_dir'])
  236. else:
  237. error_msg = "Get UNSUPPORTED model."
  238. error = UnknownModelError(error_msg)
  239. raise error
  240. def check_params_exist(params: list, config):
  241. """Check params exist."""
  242. miss_param_list = ''
  243. for param in params:
  244. if not config.get(param) or not config[param]:
  245. miss_param_list = ', '.join((miss_param_list, param)) if miss_param_list else param
  246. if miss_param_list:
  247. raise ParamMissingError(f"Param(s) missing, {miss_param_list} is(are) required when using graph mode.")