You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

api.py 21 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """Providing interface methods."""
  18. import types
  19. from collections import OrderedDict
  20. from functools import wraps
  21. from mindspore import context
  22. from mindspore import log as logger
  23. from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
  24. from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
  25. from .tensor import Tensor as MsTensor
  26. from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
  27. from ..parallel._ps_context import _is_role_pserver
# store ms_function class compiled pipeline cache.
# Maps the key produced by `generate_key` to the compiled phase name. Entries are
# only added for bound methods or functions given an identifying object (see
# `_MindSporeFunction.compile`); plain functions are recompiled on every call.
ms_compile_cache = {}
  30. def _convert_function_arguments(fn, *args):
  31. """
  32. Process the fn default parameters.
  33. Args:
  34. fn (Function): The function to be parsed.
  35. args (tuple): The parameters of the function.
  36. """
  37. arguments_dict = OrderedDict()
  38. parse_method = None
  39. if isinstance(fn, (types.FunctionType, types.MethodType)):
  40. parse_method = fn.__name__
  41. index = 0
  42. for value in args:
  43. arguments_dict[f'arg{index}'] = value
  44. index = index + 1
  45. logger.debug("fn(%r) full parameters dict is: %r", fn, arguments_dict)
  46. converted = True
  47. else:
  48. logger.warning("Find error: fn isn't function or method")
  49. converted = False
  50. return converted, arguments_dict, parse_method
  51. def _wrap_func(fn):
  52. """
  53. Wrapper function, convert return data to tensor or tuple of tensor.
  54. Args:
  55. fn (Function): The function need be wrapped.
  56. Returns:
  57. Function, a new function with return suitable format data.
  58. """
  59. @wraps(fn)
  60. def wrapper(*arg, **kwargs):
  61. results = fn(*arg, **kwargs)
  62. def _convert_data(data):
  63. if isinstance(data, Tensor) and not isinstance(data, MsTensor):
  64. return MsTensor(data)
  65. if isinstance(data, tuple):
  66. return tuple(_convert_data(x) for x in data)
  67. if isinstance(data, list):
  68. return list(_convert_data(x) for x in data)
  69. return data
  70. return _convert_data(results)
  71. return wrapper
  72. def _exec_init_graph(obj, init_phase):
  73. """Execute the parameter initializer graph."""
  74. inst_executor = Executor_.get_instance()
  75. param_dict = OrderedDict()
  76. for name, param in obj.parameters_dict().items():
  77. if not param.is_init:
  78. param_dict[name] = param
  79. param.is_init = True
  80. param.data.init_flag = True
  81. if param_dict:
  82. inst_executor.run_init_graph(param_dict, init_phase)
  83. class _MindSporeFunction:
  84. """
  85. Represents a function compiled by mind expression.
  86. _MindSporeFunction will compile the original function for every combination
  87. of argument types and shapes it is given (as well as their values, optionally).
  88. Args:
  89. fn (Function): The root function to compile.
  90. input_signature (Function): User defines signature to verify input.
  91. obj (Object): If function is a method, obj is the owner of function,
  92. else, obj is none.
  93. """
  94. def __init__(self, fn, input_signature=None, obj=None):
  95. self.fn = fn
  96. self.save_graphs = context.get_context("save_graphs")
  97. self.save_graphs_path = context.get_context("save_graphs_path")
  98. self.input_signature = input_signature
  99. self.obj = None
  100. self.identify_obj = None
  101. if hasattr(obj, fn.__name__):
  102. self.obj = obj
  103. elif obj is not None:
  104. self.identify_obj = obj
  105. self._executor = Executor_.get_instance()
  106. def build_data_init_graph(self, graph_name):
  107. """Build GE data graph and init graph for the given graph name."""
  108. if self.obj is None:
  109. logger.warning("Make sure parameter should not be used in function")
  110. para_dict = OrderedDict()
  111. self._executor.build_data_graph(para_dict, graph_name)
  112. return
  113. self._executor.build_data_graph(self.obj.parameters_dict(), graph_name, self.obj.parameters_broadcast_dict())
  114. init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
  115. _exec_init_graph(self.obj, init_phase)
  116. def compile(self, arguments_dict, method_name):
  117. """Returns pipeline for the given args."""
  118. args_list = tuple(arguments_dict.values())
  119. arg_names = tuple(arguments_dict.keys())
  120. # remove first self parameter when fn is a method
  121. if self.obj is not None:
  122. args_list = args_list[1:]
  123. arg_names = arg_names[1:]
  124. # verify the signature for both function and method
  125. if self.input_signature is not None:
  126. signatures = []
  127. for sig_spec in self.input_signature:
  128. if not isinstance(sig_spec, MetaTensor):
  129. raise TypeError("Input_signature is not MetaTensor")
  130. signatures.append(sig_spec)
  131. is_valid_input = verify_inputs_signature(signatures, args_list)
  132. if not is_valid_input:
  133. raise ValueError("Inputs is incompatible with input signature!")
  134. dic = dict(zip(arg_names, args_list))
  135. generate_name = self.fn.__module__ + "." + self.fn.__name__
  136. self.fn.__parse_method__ = method_name
  137. # replace key with obj info and object ext info when fn is a method
  138. if self.obj is not None:
  139. self.obj.__parse_method__ = method_name
  140. generate_name = self.obj.__module__ + "."
  141. if self.obj.__class__.__name__ != "ClipByNorm":
  142. generate_name = generate_name + str(self.obj.create_time)
  143. if self.identify_obj is not None:
  144. generate_name = generate_name + str(id(self.identify_obj))
  145. key = generate_key(generate_name, dic)
  146. phase = str(key[1]) + generate_name
  147. if key not in ms_compile_cache.keys():
  148. is_compile = False
  149. if self.obj is None:
  150. is_compile = self._executor.compile(self.fn, args_list, phase, True)
  151. else:
  152. is_compile = self._executor.compile(self.obj, args_list, phase, True)
  153. if not is_compile:
  154. raise RuntimeError("Executor compile failed.")
  155. if context.get_context("enable_ge"):
  156. self.build_data_init_graph(phase)
  157. # since function can be redefined, we only cache class method pipeline
  158. if self.obj is not None or self.identify_obj is not None:
  159. ms_compile_cache[key] = phase
  160. return phase
  161. return ms_compile_cache[key]
  162. @_wrap_func
  163. def __call__(self, *args):
  164. init_backend()
  165. converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args)
  166. if not converted:
  167. raise RuntimeError('Process function parameter is failure')
  168. args_list = tuple(arguments_dict.values())
  169. if self.obj is not None:
  170. args_list = args_list[1:]
  171. phase = self.compile(arguments_dict, parse_method)
  172. if context.get_context("precompile_only"):
  173. return None
  174. return self._executor(args_list, phase)
  175. def ms_function(fn=None, obj=None, input_signature=None):
  176. """
  177. Create a callable MindSpore graph from a python function.
  178. This allows the MindSpore runtime to apply optimizations based on graph.
  179. Args:
  180. fn (Function): The Python function that will be run as a graph. Default: None.
  181. obj (Object): The Python Object that provides the information for identifying the compiled function.Default:
  182. None.
  183. input_signature (MetaTensor): The MetaTensor which describes the input arguments. The MetaTensor specifies
  184. the shape and dtype of the Tensor and they will be supplied to this function. If input_signature
  185. is specified, each input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept
  186. `**kwargs`. The shape and dtype of actual inputs should keep the same as input_signature. Otherwise,
  187. TypeError will be raised. Default: None.
  188. Returns:
  189. Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
  190. None, returns a decorator and when this decorator invokes with a single `fn` argument, the callable function is
  191. equal to the case when `fn` is not None.
  192. Examples:
  193. >>> def tensor_add(x, y):
  194. >>> z = F.tensor_add(x, y)
  195. >>> return z
  196. >>>
  197. >>> @ms_function
  198. >>> def tensor_add_with_dec(x, y):
  199. >>> z = F.tensor_add(x, y)
  200. >>> return z
  201. >>>
  202. >>> @ms_function(input_signature=(MetaTensor(mindspore.float32, (1, 1, 3, 3)),
  203. >>> MetaTensor(mindspore.float32, (1, 1, 3, 3))))
  204. >>> def tensor_add_with_sig(x, y):
  205. >>> z = F.tensor_add(x, y)
  206. >>> return z
  207. >>>
  208. >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
  209. >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
  210. >>>
  211. >>> tensor_add_graph = ms_function(fn=tensor_add)
  212. >>> out = tensor_add_graph(x, y)
  213. >>> out = tensor_add_with_dec(x, y)
  214. >>> out = tensor_add_with_sig(x, y)
  215. """
  216. def wrap_mindspore(func):
  217. @wraps(func)
  218. def staging_specialize(*args):
  219. process_obj = obj
  220. if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
  221. process_obj = args[0]
  222. return _MindSporeFunction(func, input_signature, process_obj)(*args)
  223. return staging_specialize
  224. if fn is not None:
  225. return wrap_mindspore(fn)
  226. return wrap_mindspore
  227. def _generate_pip_args(obj, *args, method="construct"):
  228. """Generate arguments for pipeline."""
  229. if hasattr(obj, method):
  230. fn = getattr(obj, method)
  231. else:
  232. raise AttributeError('The process method is not exist')
  233. converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
  234. if not converted:
  235. raise RuntimeError('Process method parameter is failure')
  236. args_list = tuple(arguments_dict.values())
  237. args_names = tuple(arguments_dict.keys())
  238. obj.__parse_method__ = parse_method
  239. return args_names, args_list
  240. class _PynativeExecutor:
  241. """
  242. An pynative executor used to compile/manage/run graph.
  243. Returns:
  244. Graph, return the result of pipeline running.
  245. """
  246. def __init__(self):
  247. self._executor = PynativeExecutor_.get_instance()
  248. def new_graph(self, obj, *args, **kwargs):
  249. self._executor.new_graph(obj, *args, *(kwargs.values()))
  250. def end_graph(self, obj, output, *args, **kwargs):
  251. self._executor.end_graph(obj, output, *args, *(kwargs.values()))
  252. def grad(self, grad, obj, weights, *args, **kwargs):
  253. self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
  254. def clear(self, flag=""):
  255. self._executor.clear(flag)
  256. def set_grad_flag(self, flag):
  257. self._executor.set_grad_flag(flag)
  258. def __call__(self, *args, **kwargs):
  259. args = args + tuple(kwargs.values())
  260. return self._executor(args, "")
  261. class _Executor:
  262. """
  263. An executor used to compile/manage/run graph.
  264. Including data_graph, train_graph, eval_graph and predict graph.
  265. Returns:
  266. Graph, return the result of pipeline running.
  267. """
  268. def __init__(self):
  269. # create needed graph by lazy mode
  270. self.is_init = False
  271. self._executor = Executor_.get_instance()
  272. self.compile_cache = {}
  273. self.phase_prefix = ""
  274. def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
  275. input_indexs, phase='dataset'):
  276. """
  277. Initialization interface for calling data subgraph.
  278. Args:
  279. queue_name (str): The name of tdt queue on the device.
  280. dataset_size (int): The size of dataset.
  281. batch_size (int): The size of batch.
  282. dataset_types (list): The output types of element in dataset.
  283. dataset_shapes (list): The output shapes of element in dataset.
  284. input_indexs (list): The index of data with net.
  285. phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.
  286. Returns:
  287. bool, specifies whether the data subgraph was initialized successfully.
  288. """
  289. if not init_exec_dataset(queue_name=queue_name,
  290. size=dataset_size,
  291. batch_size=batch_size,
  292. types=dataset_types,
  293. shapes=dataset_shapes,
  294. input_indexs=input_indexs,
  295. phase=phase):
  296. raise RuntimeError("Failure to init and dataset subgraph!")
  297. return True
  298. def _build_data_graph(self, obj, phase):
  299. self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
  300. def _set_dataset_mode(self, args_list):
  301. """set dataset mode."""
  302. # decide whether to sink based on whether the inputs is virtual or args_list is ()
  303. if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
  304. (args_list is not None and args_list == ()):
  305. _set_dataset_mode_config('sink')
  306. else:
  307. _set_dataset_mode_config('normal')
  308. def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
  309. """
  310. Compiles graph.
  311. Args:
  312. obj (Function/Cell): The function or cell instance need compile.
  313. args (tuple): Function or cell input arguments.
  314. phase (str): The name of compile phase. Default: 'predict'.
  315. do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
  316. auto_parallel_mode: When set to True, use auto parallel mode to compile graph.
  317. Return:
  318. Str, the full phase of the cell.
  319. Bool, if the graph has been compiled before, return False, else return True.
  320. """
  321. args_names, args_list = _generate_pip_args(obj, *args)
  322. dic = dict(zip(args_names, args_list))
  323. key = generate_key(phase, dic)
  324. self.phase_prefix = str(key[1])
  325. if 'export' in phase:
  326. phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
  327. else:
  328. phase = self.phase_prefix + phase + '.' + str(obj.create_time)
  329. if phase in self.compile_cache.keys():
  330. logger.debug("%r graph has existed.", phase)
  331. return phase, False
  332. obj.check_names()
  333. _check_full_batch()
  334. self._set_dataset_mode(args_list)
  335. is_sink_mode = args and isinstance(args[0], Tensor) and args[0].virtual_flag
  336. if auto_parallel_mode and _need_to_full() and not is_sink_mode and obj.auto_parallel_compile_and_run():
  337. args_full = _to_full_tensor(args, _get_device_num(), _get_global_rank())
  338. _, args_list = _generate_pip_args(obj, *args_full)
  339. enable_debug_runtime = context.get_context("enable_debug_runtime")
  340. enable_ge = context.get_context("enable_ge")
  341. use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
  342. result = self._executor.compile(obj, args_list, phase, use_vm)
  343. self.compile_cache[phase] = phase
  344. if not result:
  345. raise RuntimeError("Executor compile failed.")
  346. graph = self._executor.get_func_graph(phase)
  347. if graph is None:
  348. logger.error("%r graph compile failed.", phase)
  349. if not do_convert:
  350. return phase, True
  351. if auto_parallel_mode:
  352. obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
  353. replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
  354. if not enable_debug_runtime or enable_ge:
  355. if auto_parallel_mode:
  356. obj.load_parameter_slice(None)
  357. self._updata_param_node_default_input(phase, replace)
  358. # set parallel inputs in sink mode
  359. if auto_parallel_mode and is_sink_mode:
  360. obj.set_parallel_input_with_inputs(*args)
  361. # the following GE init process is not needed when use vm or ms backend
  362. if enable_ge:
  363. self._build_data_graph(obj, phase)
  364. if "export" not in phase:
  365. init_phase = "init_subgraph" + "." + str(obj.create_time)
  366. _exec_init_graph(obj, init_phase)
  367. elif not enable_ge and "export" in phase:
  368. self._build_data_graph(obj, phase)
  369. return phase, True
  370. def _updata_param_node_default_input(self, phase, replace):
  371. new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
  372. return self._executor.updata_param_node_default_input(phase, new_param)
  373. def _get_shard_strategy(self, obj):
  374. real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
  375. return self._executor.get_strategy(real_phase)
  376. def _get_allreduce_fusion(self, obj):
  377. real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
  378. return self._executor.get_allreduce_fusion(real_phase)
  379. def has_compiled(self, phase='predict'):
  380. """
  381. Specify whether have been compiled.
  382. Args:
  383. phase (str): The phase name. Default: 'predict'.
  384. Returns:
  385. bool, specifies whether the specific graph has been compiled.
  386. """
  387. return self._executor.has_compiled(phase)
  388. def __call__(self, obj, *args, phase='predict'):
  389. if context.get_context("precompile_only") or _is_role_pserver():
  390. return None
  391. return self.run(obj, *args, phase=phase)
  392. @_wrap_func
  393. def _exec_pip(self, obj, *args, phase=''):
  394. """Execute the generated pipeline."""
  395. fn = obj.construct
  396. converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
  397. if not converted:
  398. raise RuntimeError('Process method parameter is failure')
  399. args_list = tuple(arguments_dict.values())
  400. obj.__parse_method__ = parse_method
  401. return self._executor(args_list, phase)
  402. def run(self, obj, *args, phase='predict'):
  403. """
  404. Run the specific graph.
  405. Args:
  406. phase (str): The phase name. Default: 'predict'.
  407. Returns:
  408. Tensor/Tuple, return execute result.
  409. """
  410. if phase == 'save':
  411. return self._executor((), phase + '.' + str(obj.create_time))
  412. phase_real = self.phase_prefix + phase + '.' + str(obj.create_time)
  413. if self.has_compiled(phase_real):
  414. return self._exec_pip(obj, *args, phase=phase_real)
  415. raise KeyError('{} graph is not exist.'.format(phase_real))
  416. def del_net_res(self, net_id):
  417. self._executor.del_net_res(net_id)
  418. def _get_func_graph_proto(self, exec_id, ir_type="onnx_ir", use_prefix=False):
  419. """Get graph proto from pipeline."""
  420. if use_prefix:
  421. exec_id = self.phase_prefix + exec_id
  422. if self._executor.has_compiled(exec_id) is False:
  423. return None
  424. return self._executor.get_func_graph_proto(exec_id, ir_type)
  425. def export(self, file_name, graph_id):
  426. """
  427. Export graph.
  428. Args:
  429. file_name (str): File name of model to export
  430. graph_id (str): id of graph to be exported
  431. """
  432. from .._c_expression import export_graph
  433. export_graph(file_name, 'AIR', graph_id)
  434. def fetch_info_for_quant_export(self, exec_id):
  435. """Get graph proto from pipeline."""
  436. if self._executor.has_compiled(exec_id) is False:
  437. return None
  438. return self._executor.fetch_info_for_quant_export(exec_id)
# Module-level executor instances, created when this module is imported.
_executor = _Executor()
_pynative_exec = _PynativeExecutor()
# Only `ms_function` is exported by `from ... import *`.
__all__ = ['ms_function']