You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

api.py 34 kB

code check for master # Conflicts: # mindspore/common/initializer.py # mindspore/nn/cell.py # # 似乎您正在做一个拣选提交。如果不对,请删除文件 # .git/CHERRY_PICK_HEAD # 然后重试。 # 请为您的变更输入提交说明。以 '#' 开始的行将被忽略,而一个空的提交 # 说明将会终止提交。 # # 日期: Fri Aug 13 18:40:19 2021 +0800 # # 位于分支 code_review_r1.3 # 您的分支与上游分支 'ma/r1.3' 一致。 # # 您在执行拣选提交 ffda6be35c 的操作。 # # 要提交的变更: # 修改: mindspore/common/__init__.py # 修改: mindspore/common/_register_for_tensor.py # 修改: mindspore/common/api.py # 修改: mindspore/common/dtype.py # 修改: mindspore/common/initializer.py # 修改: mindspore/common/monad.py # 修改: mindspore/common/parameter.py # 修改: mindspore/common/seed.py # 修改: mindspore/common/tensor.py # 修改: mindspore/nn/cell.py # 修改: mindspore/nn/metrics/__init__.py # 修改: mindspore/nn/metrics/confusion_matrix.py # 修改: mindspore/nn/metrics/error.py # 修改: mindspore/nn/metrics/fbeta.py # 修改: mindspore/nn/metrics/loss.py # 修改: mindspore/nn/metrics/metric.py # 修改: mindspore/nn/metrics/precision.py # 修改: mindspore/nn/metrics/recall.py # 修改: mindspore/nn/metrics/topk.py # 修改: mindspore/train/callback/_checkpoint.py # 修改: mindspore/train/model.py # 修改: mindspore/train/serialization.py # # Conflicts: # mindspore/common/api.py # mindspore/common/initializer.py # mindspore/nn/metrics/confusion_matrix.py # # 似乎您正在做一个拣选提交。如果不对,请删除文件 # .git/CHERRY_PICK_HEAD # 然后重试。 # 请为您的变更输入提交说明。以 '#' 开始的行将被忽略,而一个空的提交 # 说明将会终止提交。 # # 日期: Fri Aug 13 18:40:19 2021 +0800 # # 位于分支 code_review_master # 您的分支与上游分支 'ma/master' 一致。 # # 您在执行拣选提交 743f9fbff3 的操作。 # # 要提交的变更: # 修改: mindspore/common/__init__.py # 修改: mindspore/common/_monad.py # 修改: mindspore/common/_register_for_tensor.py # 修改: mindspore/common/api.py # 修改: mindspore/common/dtype.py # 修改: mindspore/common/initializer.py # 修改: mindspore/common/parameter.py # 修改: mindspore/common/seed.py # 修改: mindspore/common/tensor.py # 修改: mindspore/nn/cell.py # 修改: mindspore/nn/metrics/__init__.py # 修改: mindspore/nn/metrics/confusion_matrix.py # 修改: mindspore/nn/metrics/error.py # 修改: mindspore/nn/metrics/fbeta.py # 
修改: mindspore/nn/metrics/loss.py # 修改: mindspore/nn/metrics/metric.py # 修改: mindspore/nn/metrics/precision.py # 修改: mindspore/nn/metrics/recall.py # 修改: mindspore/nn/metrics/topk.py # 修改: mindspore/train/callback/_checkpoint.py # 修改: mindspore/train/model.py # 修改: mindspore/train/serialization.py #
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """Providing interface methods."""
  18. import types
  19. import sys
  20. import os
  21. import time
  22. import traceback
  23. import ast
  24. import importlib
  25. from collections import OrderedDict
  26. from functools import wraps
  27. from mindspore import context
  28. from mindspore import log as logger
  29. from mindspore._extends.remote import kernel_build_server
  30. from .tensor import Tensor as MsTensor
  31. from .tensor import CSRTensor as MsCSRTensor
  32. from .._c_expression import generate_arguments_key, GraphExecutor_, Tensor, MetaTensor, CSRTensor, PynativeExecutor_
  33. from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline
  34. from ..parallel._ps_context import _is_role_pserver
  35. from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
  36. _get_parameter_broadcast, _get_pipeline_stages
  37. from .._checkparam import Validator
# store ms_function class compiled pipeline cache
ms_compile_cache = {}
# Substring marking a broadcast phase name; used to avoid re-broadcasting.
BROADCAST_PHASE = "_broadcast_"
  41. def _convert_function_arguments(fn, *args):
  42. """
  43. Process the fn default parameters.
  44. Args:
  45. fn (Function): The function to be parsed.
  46. args (tuple): The parameters of the function.
  47. """
  48. arguments_dict = OrderedDict()
  49. parse_method = None
  50. if isinstance(fn, (types.FunctionType, types.MethodType)):
  51. parse_method = fn.__name__
  52. index = 0
  53. for value in args:
  54. arguments_dict[f'arg{index}'] = value
  55. index = index + 1
  56. logger.debug("fn(%r) full parameters dict is: %r", fn, arguments_dict)
  57. converted = True
  58. else:
  59. logger.warning("Find error: fn isn't function or method")
  60. converted = False
  61. return converted, arguments_dict, parse_method
  62. def _wrap_func(fn):
  63. """
  64. Wrapper function, convert return data to tensor or tuple of tensor.
  65. Args:
  66. fn (Function): The function need be wrapped.
  67. Returns:
  68. Function, a new function with return suitable format data.
  69. """
  70. @wraps(fn)
  71. def wrapper(*arg, **kwargs):
  72. results = fn(*arg, **kwargs)
  73. def _convert_data(data):
  74. if isinstance(data, Tensor) and not isinstance(data, MsTensor):
  75. return MsTensor(data)
  76. if isinstance(data, CSRTensor) and not isinstance(data, MsCSRTensor):
  77. return MsCSRTensor(csr_tensor=data)
  78. if isinstance(data, tuple):
  79. return tuple(_convert_data(x) for x in data)
  80. if isinstance(data, list):
  81. return list(_convert_data(x) for x in data)
  82. return data
  83. return _convert_data(results)
  84. return wrapper
def _exec_init_graph(obj, init_phase):
    """Execute the parameter initializer graph for the uninitialized parameters of `obj`."""
    inst_executor = GraphExecutor_.get_instance()
    param_dict = OrderedDict()
    for name, param in obj.parameters_dict().items():
        if not param.is_init:
            param_dict[name] = param
            # Mark the parameter so it is not initialized a second time.
            param.is_init = True
            param.data.init_flag = True
    # Only run the init graph when at least one parameter still needs init.
    if param_dict:
        inst_executor.run_init_graph(param_dict, init_phase)
  96. def _check_all_tensor(sequence):
  97. for element in sequence:
  98. if not isinstance(element, Tensor) and not (isinstance(element, tuple) and _check_all_tensor(element)):
  99. return False
  100. return True
  101. def _get_filename_from_trace(trace):
  102. # format: File "xxx.py", line x, in <module>
  103. strings = trace.strip().split(' ')
  104. filename = strings[1].rstrip(',').strip('"')
  105. return filename
  106. def __get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_files):
  107. """Get the dependency files of the network"""
  108. with open(file_path) as fh:
  109. root = ast.parse(fh.read(), file_path)
  110. for node in ast.iter_child_nodes(root):
  111. module_name = ""
  112. if isinstance(node, ast.ImportFrom):
  113. module_name = node.module
  114. elif not isinstance(node, ast.Import):
  115. continue
  116. # Do not care the files in mindspore package
  117. if module_name.startswith("mindspore"):
  118. continue
  119. for n in node.names:
  120. if n.name.startswith("mindspore"):
  121. continue
  122. if module_name == "":
  123. whole_module = n.name
  124. else:
  125. whole_module = module_name
  126. if not n.name is None:
  127. whole_module += "." + n.name
  128. try:
  129. module_spec = importlib.util.find_spec(whole_module)
  130. except (ModuleNotFoundError, ValueError):
  131. whole_module = whole_module[0:whole_module.rfind('.')]
  132. module_spec = importlib.util.find_spec(whole_module)
  133. if module_spec is None:
  134. continue
  135. module = importlib.util.module_from_spec(module_spec)
  136. if hasattr(module, '__file__'):
  137. dep_file_path = module.__file__
  138. else:
  139. continue
  140. if not dep_file_path.startswith(python_bin_dir) and not dep_file_path in compile_cache_dep_files:
  141. logger.debug(f"dependent file path: {dep_file_path}")
  142. compile_cache_dep_files.append(dep_file_path)
  143. __get_compile_cache_dep_files(dep_file_path, python_bin_dir, compile_cache_dep_files)
def _get_compile_cache_dep_files():
    """Get the dependency files of the network.

    Walks the current call stack to locate the user's entry script, then
    recursively collects every non-mindspore python file it imports.
    Returns an empty list when the entry script cannot be located or the
    python layout is not the expected `.../bin/python`.
    """
    python_bin_path = sys.executable
    # Only a conventional `.../bin/python` layout is supported; the prefix is
    # used to recognize (and skip) files belonging to the installation itself.
    if python_bin_path.endswith('bin/python'):
        python_bin_dir = python_bin_path[:-10]
    else:
        return []
    tb = traceback.format_stack()
    compile_cache_dep_files = []
    filename = None
    # Get the entry script file: the first stack frame whose source file does
    # not live inside the python installation directory.
    entry_id = 0
    while entry_id < len(tb) and _get_filename_from_trace(tb[entry_id]).startswith(python_bin_dir):
        logger.debug(f"trace: {tb[entry_id]}")
        entry_id += 1
    if entry_id < len(tb):
        filename = _get_filename_from_trace(tb[entry_id])
    if filename is None:
        return []
    file_path = os.path.realpath(filename)
    logger.debug(f"entry script file path: {file_path}")
    compile_cache_dep_files.append(file_path)
    __get_compile_cache_dep_files(file_path, python_bin_dir, compile_cache_dep_files)
    return compile_cache_dep_files
class _MindsporeFunctionExecutor:
    """
    Represents a function compiled by graph compiler.

    _MindsporeFunctionExecutor will compile the original function for every combination
    of argument types and shapes it is given (as well as their values, optionally).

    Args:
        fn (Function): The root function to compile.
        input_signature (Function): User defines signature to verify input.
        ms_create_time(TimeStamp): The time ms_function created
        obj (Object): If function is a method, obj is the owner of function,
             else, obj is none.

    Returns:
        The result of pipeline running in graph mode.
    """

    def __init__(self, fn, ms_create_time, input_signature=None, obj=None):
        self.fn = fn
        self.input_signature = input_signature
        self.obj = None
        # Only keep `obj` when it actually owns `fn`; otherwise compile `fn`
        # as a free function.
        if hasattr(obj, fn.__name__):
            self.obj = obj
        self._graph_executor = GraphExecutor_.get_instance()
        self._create_time = ms_create_time

    def build_data_init_graph(self, graph_name):
        """Build GE data graph and init graph for the given graph name."""
        if self.obj is None:
            logger.warning("Make sure parameter should not be used in function")
            para_dict = OrderedDict()
            self._graph_executor.build_data_graph(para_dict, graph_name)
            return
        self._graph_executor.build_data_graph(self.obj.parameters_dict(), graph_name,
                                              self.obj.parameters_broadcast_dict())
        # The init subgraph reuses the dotted suffix of the compiled graph name.
        init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
        _exec_init_graph(self.obj, init_phase)

    def compile(self, args_list, arg_names, method_name):
        """Returns pipeline for the given args.

        Builds a unique phase name from the function identity (and owner /
        creation time), compiles through the graph executor when the phase is
        not cached, and returns the phase string.
        """
        # Verify the signature for both function and method
        if self.input_signature is not None:
            signatures = []
            for sig_spec in self.input_signature:
                if not isinstance(sig_spec, MetaTensor):
                    raise TypeError("Input_signature is not MetaTensor")
                signatures.append(sig_spec)
            is_valid_input = verify_inputs_signature(signatures, args_list)
            if not is_valid_input:
                raise ValueError("Inputs is incompatible with input signature!")

        dic = dict(zip(arg_names, args_list))
        generate_name = self.fn.__module__ + "." + self.fn.__name__ + "." + self.fn.__code__.co_filename + "." + \
            str(self.fn.__code__.co_firstlineno) + '.' + str(id(self.fn))
        if _pynative_executor.grad_flag():
            generate_name = generate_name + ".grad"
        self.fn.__parse_method__ = method_name

        # Add key with obj
        if self.obj is not None:
            if self.obj.__module__ != self.fn.__module__:
                logger.error(f'`obj` module not equal to `fn` module: {self.obj.__module__}, {self.fn.__module__}')
            self.obj.__parse_method__ = method_name
            generate_name = generate_name + '.' + str(self.obj.create_time) + '.' + str(id(self.obj))
        else:
            # Different instance of same class may use same memory(means same obj_id) at diff times.
            # To avoid unexpected phase matched, add create_time to generate_name.
            generate_name = generate_name + '.' + str(self._create_time)

        if hasattr(self.obj, "enable_tuple_broaden"):
            self.enable_tuple_broaden = self.obj.enable_tuple_broaden
        else:
            self.enable_tuple_broaden = False
        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
        key = generate_arguments_key(dic, self.enable_tuple_broaden)
        phase = generate_name + '.' + str(key)
        # Reuse the pipeline compiled earlier for the same phase.
        if phase in ms_compile_cache.keys():
            return phase

        if self.obj is None:
            is_compile = self._graph_executor.compile(self.fn, args_list, phase, True)
        else:
            self._graph_executor.set_weights_values(self.obj.parameters_dict())
            is_compile = self._graph_executor.compile(self.obj, args_list, phase, True)
        if not is_compile:
            raise RuntimeError("Executor compile failed.")
        if context.get_context("enable_ge"):
            self.build_data_init_graph(phase)
        ms_compile_cache[phase] = phase
        return phase

    @_wrap_func
    def __call__(self, *args):
        init_pipeline()
        converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args)
        if not converted:
            raise RuntimeError('Process function parameter is failure')

        args_list = tuple(arguments_dict.values())
        arg_names = tuple(arguments_dict.keys())
        # Drop the bound owner argument before compiling a method.
        if self.obj is not None:
            args_list = args_list[1:]
            arg_names = arg_names[1:]
        phase = self.compile(args_list, arg_names, parse_method)
        if context.get_context("precompile_only"):
            return None

        # Only Tensors (plus scalars / all-tensor tuples when enabled) are
        # forwarded to the compiled graph; other arguments were folded in at
        # compile time.
        new_inputs = []
        for i in args_list:
            if isinstance(i, Tensor):
                new_inputs.append(i)
            elif context.get_context("grad_for_scalar") and isinstance(i, (int, float)):
                new_inputs.append(i)
            elif self.enable_tuple_broaden and isinstance(i, tuple) and _check_all_tensor(i):
                new_inputs.append(i)
        output = self._graph_executor(tuple(new_inputs), phase)
        if context.get_context("mode") == context.PYNATIVE_MODE:
            _pynative_executor.set_graph_phase(phase)
            output = _pynative_executor.grad_ms_function(output, *new_inputs)
        return output
def ms_function(fn=None, obj=None, input_signature=None):
    """
    Create a callable MindSpore graph from a Python function.

    This allows the MindSpore runtime to apply optimizations based on graph.

    Args:
        fn (Function): The Python function that will be run as a graph. Default: None.
        obj (Object): The Python object is used to distinguish the compiled function. Default: None.
        input_signature (Tensor): The Tensor which describes the input arguments. The shape and dtype of the Tensor
            will be supplied to this function. If input_signature is specified, each input to `fn` must be a `Tensor`.
            And the input parameters of `fn` cannot accept `**kwargs`. The shape and dtype of actual inputs should
            keep the same as input_signature. Otherwise, TypeError will be raised. Default: None.

    Returns:
        Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
        None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
        equal to the case when `fn` is not None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore import ms_function
        ...
        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
        ...
        >>> # create a callable MindSpore graph by calling ms_function
        >>> def tensor_add(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> tensor_add_graph = ms_function(fn=tensor_add)
        >>> out = tensor_add_graph(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function
        >>> @ms_function
        ... def tensor_add_with_dec(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_dec(x, y)
        ...
        >>> # create a callable MindSpore graph through decorator @ms_function with input_signature parameter
        >>> @ms_function(input_signature=(Tensor(np.ones([1, 1, 3, 3]).astype(np.float32)),
        ...                               Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))))
        ... def tensor_add_with_sig(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_sig(x, y)
    """
    def wrap_mindspore(func):
        # Creation time disambiguates phases when different function objects
        # happen to reuse the same id() at different times.
        ms_create_time = int(time.time() * 1e9)

        @wraps(func)
        def staging_specialize(*args):
            if obj is not None:
                logger.warning("Obj is no longer in use, and the function's own object has been used to \
distinguish whether it has been compiled.")
            process_obj = None
            # When invoked as a bound method, args[0] is the owning object.
            if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
                process_obj = args[0]
            out = _MindsporeFunctionExecutor(func, ms_create_time, input_signature, process_obj)(*args)
            return out

        return staging_specialize

    # Support both @ms_function and @ms_function(...) usage.
    if fn is not None:
        return wrap_mindspore(fn)
    return wrap_mindspore
  342. def _generate_pip_args(obj, *args, method="construct"):
  343. """Generate arguments for pipeline."""
  344. if hasattr(obj, method):
  345. fn = getattr(obj, method)
  346. else:
  347. raise AttributeError('The process method is not exist')
  348. converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
  349. if not converted:
  350. raise RuntimeError('Process method parameter is failure')
  351. args_list = tuple(arguments_dict.values())
  352. args_names = tuple(arguments_dict.keys())
  353. obj.__parse_method__ = parse_method
  354. return args_names, args_list
  355. def _get_auto_split_param_names(parameter_layout_dict):
  356. auto_split_param_names = []
  357. for key, value in parameter_layout_dict.items():
  358. for dim in value[1]:
  359. if dim != -1:
  360. auto_split_param_names.append(key)
  361. break
  362. return auto_split_param_names
def _build_broadcast_graph(broadcast_params_dict, broadcast_phase):
    """Build broadcast graph.

    Runs a _BroadCastCell over the given parameters and writes the broadcast
    results back into the original Parameter objects in place.
    """
    from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
    if not broadcast_params_dict:
        broadcast_params_dict = {}
    broadcast_params = []
    for param in broadcast_params_dict.values():
        broadcast_params.append(Tensor(param.asnumpy()))
    _broadcast_net = _BroadCastCell(broadcast_params)
    _broadcast_net.phase = broadcast_phase
    broadcasted_params = _broadcast_net()
    # zip order matches because broadcast_params was built from .values().
    for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params):
        broadcast_params_dict[param_name].set_data(param)
def _parameter_broadcast(obj, auto_parallel_mode):
    """Parameter broadcast.

    Broadcasts the broadcastable parameters of `obj`. In auto parallel mode,
    parameters that are already split across devices are excluded so their
    shards are not overwritten.
    """
    auto_split_param_names = []
    if auto_parallel_mode:
        auto_split_param_names = _get_auto_split_param_names(obj.parameter_layout_dict)

    broadcast_params_dict = obj.parameters_broadcast_dict()
    if auto_split_param_names and broadcast_params_dict:
        # Rebuild the dict without the auto-split parameters.
        broadcast_params_dict = OrderedDict()
        for param_name, param in obj.parameters_broadcast_dict().items():
            if param_name not in auto_split_param_names:
                broadcast_params_dict[param_name] = param
    broadcast_phase = "_broadcast_subgraph"
    _build_broadcast_graph(broadcast_params_dict, broadcast_phase)
class _PynativeExecutor:
    """
    A pynative executor used to compile/manage/run single op.

    The main functions include: construct op graph, compile op graph, auto grad and run op graph.
    Almost every method is a thin wrapper delegating to the C++ PynativeExecutor_ singleton.

    Args:
        obj (Object): The python network that will be run in pynative mode.
        args (Tuple(Tensor...)): The inputs of network in tuple form.

    Returns:
        gradients (Tuple(Tensor...)): The gradients of network parameters and inputs.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def __init__(self):
        self._executor = PynativeExecutor_.get_instance()
        # Tell the backend where python and the kernel build server live.
        self._executor.set_py_exe_path(sys.executable)
        self._executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)

    def new_graph(self, obj, *args, **kwargs):
        """Start building a graph for `obj` with the given inputs."""
        self._executor.new_graph(obj, *args, *(kwargs.values()))

    def end_graph(self, obj, output, *args, **kwargs):
        """Finish building the graph of `obj`, recording its `output`."""
        self._executor.end_graph(obj, output, *args, *(kwargs.values()))

    def check_graph(self, obj, *args, **kwargs):
        """Delegate the backend graph check for `obj` and these inputs."""
        return self._executor.check_graph(obj, *args, *(kwargs.values()))

    def check_run(self, grad, obj, *args, **kwargs):
        """Delegate the backend run check for the grad graph of `obj`."""
        return self._executor.check_run(grad, obj, *args, *(kwargs.values()))

    def grad(self, grad, obj, weights, grad_position, *args, **kwargs):
        """Build the grad net of `obj` for the given grad operation."""
        self._executor.grad_net(grad, obj, weights, grad_position, *args, *(kwargs.values()))

    def del_cell(self, cell_id=""):
        """Clear the backend resources recorded for the cell `cell_id`."""
        self._executor.clear_cell(cell_id)

    def clear_res(self):
        """Clear all pynative resources held by the backend executor."""
        return self._executor.clear_res()

    def clear_grad(self, obj, *args, **kwargs):
        """Clear the grad resources of `obj` in the backend."""
        self._executor.clear_grad(obj, *args, *(kwargs.values()))

    def sync(self):
        """Synchronize the backend executor."""
        self._executor.sync()

    def set_lazy_build(self, enable):
        """Enable or disable lazy build in the backend."""
        self._executor.set_lazy_build(enable)

    def execute_all_task(self):
        """Execute all tasks queued in the backend."""
        self._executor.execute_all_task()

    def grad_ms_function(self, output, *args):
        """Forward an ms_function output to the backend grad recording; returns the result."""
        return self._executor.grad_ms_function(output, *args)

    def set_graph_phase(self, phase):
        """Set the current graph phase in the backend."""
        self._executor.set_graph_phase(phase)

    def grad_flag(self):
        """Return the backend grad flag."""
        return self._executor.grad_flag()

    def set_grad_flag(self, flag):
        """Set the backend grad flag."""
        self._executor.set_grad_flag(flag)

    def parameter_broadcast(self, obj, phase, auto_parallel_mode):
        """Broadcast parameters of `obj` unless `phase` is already a broadcast phase."""
        if BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            _parameter_broadcast(obj, auto_parallel_mode)

    def enter_cell(self):
        """Delegate cell-enter notification to the backend."""
        self._executor.enter_cell()

    def exit_cell(self):
        """Delegate cell-exit notification to the backend."""
        self._executor.exit_cell()

    def is_top_cell(self):
        """Return whether the backend considers the current cell the top cell."""
        return self._executor.is_top_cell()

    def __call__(self, obj, *args, **kwargs):
        # Keyword argument values are appended after positional args before
        # handing the call to the backend executor.
        args = args + tuple(kwargs.values())
        return self._executor(obj, args)
class _CellGraphExecutor:
    """
    An executor used to compile/manage/run graph for a Cell.

    Including data_graph, train_graph, eval_graph and predict graph.

    Args:
        obj (Function/Cell): The function or cell instance need compile.
        args (tuple): Function or cell input arguments.

    Returns:
        Graph, return the result of pipeline running.
    """

    # Recognized jit config keys and their allowed values.
    VALID_JIT_CONFIG_PARAM = ["jit_level"]
    VALID_JIT_CONFIG_PARAM_VALUE = {
        "jit_level": ["o0", "o1"]
    }
    def __init__(self):
        # create needed graph by lazy mode
        self.is_init = False
        self._graph_executor = GraphExecutor_.get_instance()
        # Tell the backend where python and the kernel build server live.
        self._graph_executor.set_py_exe_path(sys.executable)
        self._graph_executor.set_kernel_build_server_dir(os.path.split(kernel_build_server.__file__)[0] + os.sep)
    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                     input_indexs, phase='dataset'):
        """
        Initialization interface for calling data subgraph.

        Args:
            queue_name (str): The name of tdt queue on the device.
            dataset_size (int): The size of dataset.
            batch_size (int): The size of batch.
            dataset_types (list): The output types of element in dataset.
            dataset_shapes (list): The output shapes of element in dataset.
            input_indexs (list): The index of data with net.
            phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.

        Returns:
            bool, specifies whether the data subgraph was initialized successfully.

        Raises:
            RuntimeError: If the backend fails to initialize the dataset subgraph.
        """
        if not init_exec_dataset(queue_name=queue_name,
                                 size=dataset_size,
                                 batch_size=batch_size,
                                 types=dataset_types,
                                 shapes=dataset_shapes,
                                 input_indexs=input_indexs,
                                 phase=phase):
            raise RuntimeError("Failure to init and dataset subgraph!")
        self._graph_executor.set_queue_name(queue_name)
        return True
    def _build_data_graph(self, obj, phase):
        """Build the data graph of `obj` for the given phase."""
        self._graph_executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
  494. def _set_dataset_mode(self, args_list):
  495. """set dataset mode."""
  496. # decide whether to sink based on whether the inputs is virtual or args_list is ()
  497. if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
  498. (args_list is not None and args_list == ()):
  499. _set_dataset_mode_config('sink')
  500. else:
  501. _set_dataset_mode_config('normal')
  502. @staticmethod
  503. def _use_vm_mode():
  504. enable_ge = context.get_context("enable_ge")
  505. enable_debug_runtime = context.get_context("enable_debug_runtime")
  506. exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
  507. return not enable_ge or (enable_debug_runtime and exe_mode)
    def _set_compile_cache_dep_files(self, phase):
        """Register the network's dependency files with the backend when compile cache is on."""
        # If enable compile cache, get the dependency files list
        enable_compile_cache = context.get_context("enable_compile_cache")
        if enable_compile_cache is None:
            # Fall back to the environment switch when the context is unset.
            enable_compile_cache = os.getenv('MS_COMPILER_CACHE_ENABLE')
        # Dependency tracking is only performed for training phases.
        if "train" in phase and (enable_compile_cache is True or enable_compile_cache == "1"):
            self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
    def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
        """
        Compiles graph.

        Args:
            obj (Function/Cell): The function or cell instance need compile.
            args (tuple): Function or cell input arguments.
            phase (str): The name of compile phase. Default: 'predict'.
            do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
            auto_parallel_mode: When set to True, use auto parallel mode to compile graph.

        Return:
            Str, the full phase of the cell.
            Bool, if the graph has been compiled before, return False, else return True.
        """
        args_names, args_list = _generate_pip_args(obj, *args)
        dic = dict(zip(args_names, args_list))
        if hasattr(obj, "enable_tuple_broaden"):
            self.enable_tuple_broaden = obj.enable_tuple_broaden
        else:
            self.enable_tuple_broaden = False
        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
        key = generate_arguments_key(dic, self.enable_tuple_broaden)
        obj.arguments_key = str(key)
        # The full phase uniquely identifies cell instance + argument types.
        phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key

        if phase in obj.compile_cache and self.has_compiled(phase):
            logger.debug("%r graph has existed.", phase)
            return phase, False

        obj.check_names()
        _check_full_batch()
        self._set_dataset_mode(args_list)
        self._set_compile_cache_dep_files(phase)

        is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag
        if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
            # Non-sink auto-parallel runs need inputs expanded to full-batch shape.
            args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
            _, args_list = _generate_pip_args(obj, *args_full)

        enable_ge = context.get_context("enable_ge")
        use_vm = self._use_vm_mode()
        self._graph_executor.set_weights_values(obj.parameters_dict())
        result = self._graph_executor.compile(obj, args_list, phase, use_vm)
        obj.compile_cache.add(phase)
        if not result:
            raise RuntimeError("Executor compile failed.")
        graph = self._graph_executor.get_func_graph(phase)

        if graph is None:
            raise RuntimeError("Compile graph failed for phase {}.".format(phase))

        self._auto_parallel_process(obj, phase, is_sink_mode, auto_parallel_mode, *args)

        if not do_convert:
            return phase, True

        # the following GE init process is not needed when use vm or ms backend
        if enable_ge:
            self._build_data_graph(obj, phase)
            if "export" not in phase:
                init_phase = "init_subgraph." + str(obj.create_time) + "." + str(id(obj))
                _exec_init_graph(obj, init_phase)
        elif "export" in phase:
            self._build_data_graph(obj, phase)
        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            _parameter_broadcast(obj, auto_parallel_mode)
        return phase, True
    def _auto_parallel_process(self, obj, phase, is_sink_mode, auto_parallel_mode, *args):
        """compile graph in auto parallel mode.

        Initializes parameter data after compilation and, in auto parallel
        mode, applies the parameter layout / slicing decided by the compiler.
        """
        if not auto_parallel_mode:
            replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
            self._update_param_node_default_input(phase, replace)
            return

        obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
        obj.parallel_parameter_name_list = self._graph_executor.get_parallel_parameter_name_list(phase)
        replace = obj.init_parameters_data(auto_parallel_mode=True)
        if _get_pipeline_stages() > 1 and (not hasattr(obj, "is_first_iteration") or not obj.is_first_iteration):
            # In pipeline parallel, drop parameters this stage does not own.
            obj.remove_redundant_parameters()
        if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"):
            obj.load_parameter_slice(None)
        self._update_param_node_default_input(phase, replace)

        # set parallel inputs in sink mode
        if is_sink_mode:
            obj.set_parallel_input_with_inputs(*args)
  590. def _update_param_node_default_input(self, phase, replace):
  591. new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
  592. return self._graph_executor.updata_param_node_default_input(phase, new_param)
  593. def _get_shard_strategy(self, obj):
  594. real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
  595. return self._graph_executor.get_strategy(real_phase)
  596. def _get_num_parallel_ops(self, obj):
  597. real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
  598. return self._graph_executor.get_num_parallel_ops(real_phase)
  599. def _get_allreduce_fusion(self, obj):
  600. real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
  601. return self._graph_executor.get_allreduce_fusion(real_phase)
  602. def has_compiled(self, phase='predict'):
  603. """
  604. Specify whether have been compiled.
  605. Args:
  606. phase (str): The phase name. Default: 'predict'.
  607. Returns:
  608. bool, specifies whether the specific graph has been compiled.
  609. """
  610. return self._graph_executor.has_compiled(phase)
  611. def __call__(self, obj, *args, phase='predict'):
  612. if context.get_context("precompile_only") or _is_role_pserver():
  613. return None
  614. return self.run(obj, *args, phase=phase)
  615. @_wrap_func
  616. def _exec_pip(self, obj, *args, phase=''):
  617. """Execute the generated pipeline."""
  618. fn = obj.construct
  619. converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
  620. if not converted:
  621. raise RuntimeError('Process method parameter is failure')
  622. args_list = tuple(arguments_dict.values())
  623. obj.__parse_method__ = parse_method
  624. return self._graph_executor(args_list, phase)
  625. def run(self, obj, *args, phase='predict'):
  626. """
  627. Run the specific graph.
  628. Args:
  629. phase (str): The phase name. Default: 'predict'.
  630. Returns:
  631. Tensor/Tuple, return execute result.
  632. """
  633. if phase == 'save':
  634. return self._graph_executor((), phase + '.' + str(obj.create_time) + '.' + str(id(obj)))
  635. phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
  636. if self.has_compiled(phase_real):
  637. return self._exec_pip(obj, *args, phase=phase_real)
  638. raise KeyError('{} graph is not exist.'.format(phase_real))
  639. def del_net_res(self, net_id):
  640. self._graph_executor.del_net_res(net_id)
  641. def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False):
  642. """Get graph proto from pipeline."""
  643. if use_prefix:
  644. exec_id = exec_id + '.' + obj.arguments_key
  645. if self._graph_executor.has_compiled(exec_id) is False:
  646. return None
  647. return self._graph_executor.get_func_graph_proto(exec_id, ir_type)
  648. def get_optimize_graph_proto(self, obj):
  649. """Return optimize graph binary proto."""
  650. exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
  651. if self._graph_executor.has_compiled(exec_id) is False:
  652. return None
  653. graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
  654. if isinstance(graph_proto, str) and graph_proto == "":
  655. logger.warning("Can not get optimize graph proto. Instead, try to find function graph.")
  656. graph_proto = obj.get_func_graph_proto()
  657. return graph_proto
  658. def export(self, file_name, graph_id):
  659. """
  660. Export graph.
  661. Args:
  662. file_name (str): File name of model to export
  663. graph_id (str): id of graph to be exported
  664. """
  665. from .._c_expression import export_graph
  666. export_graph(file_name, 'AIR', graph_id)
  667. def fetch_info_for_quant_export(self, exec_id):
  668. """Get graph proto from pipeline."""
  669. if self._graph_executor.has_compiled(exec_id) is False:
  670. return None
  671. return self._graph_executor.fetch_info_for_quant_export(exec_id)
  672. def set_jit_config(self, jit_config):
  673. """Set jit config."""
  674. self._check_jit_config(jit_config)
  675. self._graph_executor.set_jit_config(jit_config)
  676. def _check_jit_config(self, jit_config):
  677. """Check the value of jit config."""
  678. if not isinstance(jit_config, dict):
  679. raise ValueError("The jit_config should be a string.")
  680. for param_name, param_value in jit_config.items():
  681. Validator.check_string(param_name, self.VALID_JIT_CONFIG_PARAM, "jit_config")
  682. Validator.check_string(param_value, self.VALID_JIT_CONFIG_PARAM_VALUE.get(param_name), param_name,
  683. "jit_config")
# Module-level singleton used to compile and run Cell graphs.
_cell_graph_executor = _CellGraphExecutor()
# Module-level singleton for the PyNative executor (class defined earlier in
# this file; presumably drives op-by-op execution — confirm against its def).
_pynative_executor = _PynativeExecutor()

# Only `ms_function` is exported as this module's public API.
__all__ = ['ms_function']