You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

api.py 21 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """Providing interface methods."""
  18. import types
  19. from collections import OrderedDict
  20. from functools import wraps
  21. from mindspore import context
  22. from mindspore import log as logger
  23. from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
  24. from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
  25. from .tensor import Tensor as MsTensor
# store ms_function class compiled pipeline cache
# Keys are produced by generate_key(...) inside _MindSporeFunction.compile;
# values are the compiled phase names. Only method/identified-object pipelines
# are cached (plain functions can be redefined, so they are recompiled).
ms_compile_cache = {}
  28. def _convert_function_arguments(fn, *args):
  29. """
  30. Process the fn default parameters.
  31. Args:
  32. fn (Function): The function to be parsed.
  33. args (tuple): The parameters of the function.
  34. """
  35. arguments_dict = OrderedDict()
  36. parse_method = None
  37. if isinstance(fn, (types.FunctionType, types.MethodType)):
  38. parse_method = fn.__name__
  39. index = 0
  40. for value in args:
  41. arguments_dict[f'arg{index}'] = value
  42. index = index + 1
  43. logger.debug("fn(%r) full parameters dict is: %r", fn, arguments_dict)
  44. converted = True
  45. else:
  46. logger.warning("Find error: fn isn't function or method")
  47. converted = False
  48. return converted, arguments_dict, parse_method
  49. def _wrap_func(fn):
  50. """
  51. Wrapper function, convert return data to tensor or tuple of tensor.
  52. Args:
  53. fn (Function): The function need be wrapped.
  54. Returns:
  55. Function, a new function with return suitable format data.
  56. """
  57. @wraps(fn)
  58. def wrapper(*arg, **kwargs):
  59. results = fn(*arg, **kwargs)
  60. def _convert_data(data):
  61. if isinstance(data, Tensor) and not isinstance(data, MsTensor):
  62. return MsTensor(data)
  63. if isinstance(data, tuple):
  64. return tuple(_convert_data(x) for x in data)
  65. if isinstance(data, list):
  66. return list(_convert_data(x) for x in data)
  67. return data
  68. return _convert_data(results)
  69. return wrapper
  70. def _exec_init_graph(obj, init_phase):
  71. """Execute the parameter initializer graph."""
  72. inst_executor = Executor_.get_instance()
  73. param_dict = OrderedDict()
  74. for name, param in obj.parameters_dict().items():
  75. if not param.is_init:
  76. param_dict[name] = param
  77. param.is_init = True
  78. param.data.init_flag = True
  79. if param_dict:
  80. inst_executor.run_init_graph(param_dict, init_phase)
class _MindSporeFunction:
    """
    Represents a function compiled by mind expression.

    _MindSporeFunction will compile the original function for every combination
    of argument types and shapes it is given (as well as their values, optionally).

    Args:
        fn (Function): The root function to compile.
        input_signature (Function): User defines signature to verify input.
        obj (Object): If function is a method, obj is the owner of function,
             else, obj is none.
    """

    def __init__(self, fn, input_signature=None, obj=None):
        self.fn = fn
        self.save_graphs = context.get_context("save_graphs")
        self.save_graphs_path = context.get_context("save_graphs_path")
        self.input_signature = input_signature
        self.obj = None
        self.identify_obj = None
        # obj counts as the method owner only when it actually exposes fn;
        # otherwise it is kept merely as identification info for the cache key.
        if hasattr(obj, fn.__name__):
            self.obj = obj
        elif obj is not None:
            self.identify_obj = obj
        self._executor = Executor_.get_instance()

    def build_data_init_graph(self, graph_name):
        """Build GE data graph and init graph for the given graph name."""
        if self.obj is None:
            # No owning cell: build the data graph with an empty parameter dict.
            logger.warning("Make sure parameter should not be used in function")
            para_dict = OrderedDict()
            self._executor.build_data_graph(para_dict, graph_name)
            return
        self._executor.build_data_graph(self.obj.parameters_dict(), graph_name, self.obj.parameters_broadcast_dict())
        # Init-subgraph phase reuses the suffix of graph_name from its first '.'.
        init_phase = "init_subgraph" + graph_name[graph_name.find("."):]
        _exec_init_graph(self.obj, init_phase)

    def compile(self, arguments_dict, method_name):
        """Returns pipline for the given args.

        Compiles on a cache miss; the compiled phase name is cached (in the
        module-level ms_compile_cache) only for methods or identified objects.
        Raises TypeError/ValueError on signature mismatch and RuntimeError
        when the backend compile fails.
        """
        args_list = tuple(arguments_dict.values())
        arg_names = tuple(arguments_dict.keys())

        # remove first self parameter when fn is a method
        if self.obj is not None:
            args_list = args_list[1:]
            arg_names = arg_names[1:]

        # verify the signature for both function and method
        if self.input_signature is not None:
            signatures = []
            for sig_spec in self.input_signature:
                if not isinstance(sig_spec, MetaTensor):
                    raise TypeError("Input_signature is not MetaTensor")
                signatures.append(sig_spec)
            is_valid_input = verify_inputs_signature(signatures, args_list)
            if not is_valid_input:
                raise ValueError("Inputs is incompatible with input signature!")

        dic = dict(zip(arg_names, args_list))
        generate_name = self.fn.__module__ + "." + self.fn.__name__
        self.fn.__parse_method__ = method_name

        # replace key with obj info and object ext info when fn is a method
        if self.obj is not None:
            self.obj.__parse_method__ = method_name
            generate_name = self.obj.__module__ + "."
            # NOTE(review): ClipByNorm is special-cased so its cache key omits
            # the per-instance create_time — confirm intent with the cell code.
            if self.obj.__class__.__name__ != "ClipByNorm":
                generate_name = generate_name + str(self.obj.create_time)
        if self.identify_obj is not None:
            generate_name = generate_name + str(id(self.identify_obj))

        key = generate_key(generate_name, dic)
        phase = str(key[1]) + generate_name
        if key not in ms_compile_cache.keys():
            is_compile = False
            if self.obj is None:
                is_compile = self._executor.compile(self.fn, args_list, phase, True)
            else:
                is_compile = self._executor.compile(self.obj, args_list, phase, True)
            if not is_compile:
                raise RuntimeError("Executor compile failed.")
            if context.get_context("enable_ge"):
                self.build_data_init_graph(phase)
            # since function can be redefined, we only cache class method pipeline
            if self.obj is not None or self.identify_obj is not None:
                ms_compile_cache[key] = phase
            return phase
        return ms_compile_cache[key]

    @_wrap_func
    def __call__(self, *args):
        """Compile (or reuse) the pipeline for *args and execute it.

        Returns None when the context is precompile_only; outputs are
        converted to MindSpore tensors by the _wrap_func decorator.
        """
        init_backend()
        converted, arguments_dict, parse_method = _convert_function_arguments(self.fn, *args)
        if not converted:
            raise RuntimeError('Process function parameter is failure')

        args_list = tuple(arguments_dict.values())
        # Drop the leading `self` argument when fn is a bound method.
        if self.obj is not None:
            args_list = args_list[1:]

        phase = self.compile(arguments_dict, parse_method)

        if context.get_context("precompile_only"):
            return None
        return self._executor(args_list, phase)
  173. def ms_function(fn=None, obj=None, input_signature=None):
  174. """
  175. Creates a callable MindSpore graph from a python function.
  176. This allows the MindSpore runtime to apply optimizations based on graph.
  177. Args:
  178. fn (Function): The Python function that will be run as a graph. Default: None.
  179. obj (Object): The Python Object that provide information for identify compiled function. Default: None.
  180. input_signature (MetaTensor): The MetaTensor to describe the input arguments. The MetaTensor specifies
  181. the shape and dtype of the Tensor and they will be supplied to this function. If input_signature
  182. is specified, every input to `fn` must be a `Tensor`. And the input parameters of `fn` cannot accept
  183. `**kwargs`. The shape and dtype of actual inputs should keep same with input_signature, or TypeError
  184. will be raised. Default: None.
  185. Returns:
  186. Function, if `fn` is not None, returns a callable that will execute the compiled function; If `fn` is None,
  187. returns a decorator and when this decorator invokes with a single `fn` argument, the callable is equal to the
  188. case when `fn` is not None.
  189. Examples:
  190. >>> def tensor_add(x, y):
  191. >>> z = F.tensor_add(x, y)
  192. >>> return z
  193. >>>
  194. >>> @ms_function
  195. >>> def tensor_add_with_dec(x, y):
  196. >>> z = F.tensor_add(x, y)
  197. >>> return z
  198. >>>
  199. >>> @ms_function(input_signature=(MetaTensor(mindspore.float32, (1, 1, 3, 3)),
  200. >>> MetaTensor(mindspore.float32, (1, 1, 3, 3))))
  201. >>> def tensor_add_with_sig(x, y):
  202. >>> z = F.tensor_add(x, y)
  203. >>> return z
  204. >>>
  205. >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
  206. >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
  207. >>>
  208. >>> tensor_add_graph = ms_function(fn=tensor_add)
  209. >>> out = tensor_add_graph(x, y)
  210. >>> out = tensor_add_with_dec(x, y)
  211. >>> out = tensor_add_with_sig(x, y)
  212. """
  213. def wrap_mindspore(func):
  214. @wraps(func)
  215. def staging_specialize(*args):
  216. process_obj = obj
  217. if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
  218. process_obj = args[0]
  219. args = (x.default_input if hasattr(x, 'default_input') else x for x in args)
  220. return _MindSporeFunction(func, input_signature, process_obj)(*args)
  221. return staging_specialize
  222. if fn is not None:
  223. return wrap_mindspore(fn)
  224. return wrap_mindspore
  225. def _generate_pip_args(obj, *args, method="construct"):
  226. """Generate arguments for pipeline."""
  227. if hasattr(obj, method):
  228. fn = getattr(obj, method)
  229. else:
  230. raise AttributeError('The process method is not exist')
  231. converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
  232. if not converted:
  233. raise RuntimeError('Process method parameter is failure')
  234. args_list = tuple(arguments_dict.values())
  235. args_names = tuple(arguments_dict.keys())
  236. obj.__parse_method__ = parse_method
  237. return args_names, args_list
class _PynativeExecutor:
    """
    An pynative executor used to compile/manage/run graph.

    Every method is a thin delegation to the shared C++ PynativeExecutor_
    singleton.

    Returns:
        Graph, return the result of pipeline running.
    """

    def __init__(self):
        self._executor = PynativeExecutor_.get_instance()

    def new_graph(self, obj, *args):
        """Forward a graph-construction start to the backend executor."""
        self._executor.new_graph(obj, *args)

    def end_graph(self, obj, output, *args):
        """Forward a graph-construction end (with its output) to the backend executor."""
        self._executor.end_graph(obj, output, *args)

    def grad(self, grad, obj, weights, *args):
        """Forward gradient-graph building to the backend executor's grad_net."""
        self._executor.grad_net(grad, obj, weights, *args)

    def clear(self, flag=""):
        """Clear backend executor state; `flag` is passed through to the backend."""
        self._executor.clear(flag)

    def set_grad_flag(self, flag):
        """Set the gradient flag on the backend executor."""
        self._executor.set_grad_flag(flag)

    def __call__(self, *args):
        # Run the backend pipeline with an empty phase name.
        return self._executor(args, "")
class _Executor:
    """
    An executor used to compile/manage/run graph.

    Including data_graph, train_graph, eval_graph and predict graph.

    Returns:
        Graph, return the result of pipeline running.
    """

    def __init__(self):
        # create needed graph by lazy mode
        self.is_init = False
        self._executor = Executor_.get_instance()
        # Maps full phase name -> phase name; membership marks "already compiled".
        self.compile_cache = {}
        # Prefix derived from the latest generate_key result; prepended to phases.
        self.phase_prefix = ""

    def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes,
                     input_indexs, phase='dataset'):
        """
        Initialization interface for calling data subgraph.

        Args:
            queue_name (str): The name of tdt queue on the device.
            dataset_size (int): The size of dataset.
            batch_size (int): The size of batch.
            dataset_types (list): The output types of element in dataset.
            dataset_shapes (list): The output shapes of element in dataset.
            input_indexs (list): The index of data with net.
            phase (str): The name of phase, e.g., train_dataset/eval_dataset. Default: 'dataset'.

        Returns:
            bool, specifies whether the data subgraph was initialized successfully.

        Raises:
            RuntimeError: If the backend dataset subgraph initialization fails.
        """
        if not init_exec_dataset(queue_name=queue_name,
                                 size=dataset_size,
                                 batch_size=batch_size,
                                 types=dataset_types,
                                 shapes=dataset_shapes,
                                 input_indexs=input_indexs,
                                 phase=phase):
            raise RuntimeError("Failure to init and dataset subgraph!")
        return True

    def _build_data_graph(self, obj, params, phase):
        """Build the data graph from the cell's parameters or an explicit dict."""
        if params is None:
            self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
        elif isinstance(params, OrderedDict):
            self._executor.build_data_graph(params, phase)
        else:
            raise TypeError('Parameters need OrderedDict type, but got {}'.
                            format(type(params)))

    def _params_init_data(self, obj, params, auto_parallel_mode=False):
        """Init parameters' data."""
        if params is not None:
            for key, param in params.items():
                if not auto_parallel_mode:
                    param.init_data()
                elif key not in obj.parameter_layout_dict:
                    # No layout recorded for this parameter: init it as sliced anyway.
                    logger.debug("Layout dict does not contain the key %s.", key)
                    param.init_data(set_sliced=True)
                else:
                    layout = obj.parameter_layout_dict[key]
                    param.init_data(layout, set_sliced=True)
        obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)

    def _set_dataset_mode(self, args_list):
        """set dataset mode."""
        # decide whether to sink based on whether the inputs is virtual or args_list is ()
        if (args_list and isinstance(args_list[0], Tensor) and args_list[0].virtual_flag) or \
                (args_list is not None and args_list == ()):
            _set_dataset_mode_config('sink')
        else:
            _set_dataset_mode_config('normal')

    def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
        """
        Compiles graph.

        Args:
            obj (Function/Cell): The function or cell instance need compile.
            args (tuple): Function or cell input arguments.
            phase (str): The name of compile phase. Default: 'predict'.
            params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
            do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
            auto_parallel_mode: When set to True, use auto parallel mode to compile graph.

        Return:
            Str, the full phase of the cell.
            Bool, if the graph has been compiled before, return False, else return True.

        Raises:
            RuntimeError: If the backend executor compile fails.
        """
        obj.check_names()
        args_names, args_list = _generate_pip_args(obj, *args)
        dic = dict(zip(args_names, args_list))
        key = generate_key(phase, dic)
        self.phase_prefix = str(key[1])
        # 'export' phases keep the literal prefix in front so they can be
        # recognized later (see export() and run()).
        if phase == 'export':
            phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time)
        else:
            phase = self.phase_prefix + phase + '.' + str(obj.create_time)
        enable_debug_runtime = context.get_context("enable_debug_runtime")
        enable_ge = context.get_context("enable_ge")

        # Use the VM backend unless GE is enabled; the debug runtime in
        # pynative mode still forces the VM path.
        use_vm = not enable_ge or (enable_debug_runtime and context.get_context("mode") == context.PYNATIVE_MODE)
        self._set_dataset_mode(args_list)

        if phase in self.compile_cache.keys():
            logger.debug("%r graph has existed.", phase)
            return phase, False

        result = self._executor.compile(obj, args_list, phase, use_vm)
        self.compile_cache[phase] = phase
        if not result:
            raise RuntimeError("Executor compile failed.")
        graph = self._executor.get_func_graph(phase)

        if graph is None:
            logger.error("%r graph compile failed.", phase)
        if not do_convert:
            return phase, True

        if auto_parallel_mode:
            obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
        self._params_init_data(obj, params, auto_parallel_mode)
        if not enable_debug_runtime or enable_ge:
            if auto_parallel_mode:
                obj.load_parameter_slice(params)

        # set parallel inputs in sink mode
        if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
            obj.set_parallel_input_with_inputs(*args)

        # the following GE init process is not needed when use vm or ms backend
        if enable_ge:
            self._build_data_graph(obj, params, phase)

            if "export" not in phase:
                init_phase = "init_subgraph" + "." + str(obj.create_time)
                _exec_init_graph(obj, init_phase)
        elif not enable_ge and "export" in phase:
            self._build_data_graph(obj, params, phase)

        return phase, True

    def _get_strategy(self, obj):
        """Fetch the parallel strategy recorded for obj's compiled phase."""
        real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_strategy(real_phase)

    def _get_allreduce_fusion(self, obj):
        """Fetch the allreduce fusion info recorded for obj's compiled phase."""
        real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
        return self._executor.get_allreduce_fusion(real_phase)

    def has_compiled(self, phase='predict'):
        """
        Specify whether have been compiled.

        Args:
            phase (str): The phase name. Default: 'predict'.

        Returns:
            bool, specifies whether the specific graph has been compiled.
        """
        return self._executor.has_compiled(phase)

    def __call__(self, obj, *args, phase='predict'):
        # In precompile_only mode nothing is executed.
        if context.get_context("precompile_only"):
            return None
        return self.run(obj, *args, phase=phase)

    @_wrap_func
    def _exec_pip(self, obj, *args, phase=''):
        """Execute the generated pipeline."""
        fn = obj.construct
        converted, arguments_dict, parse_method = _convert_function_arguments(fn, *args)
        if not converted:
            raise RuntimeError('Process method parameter is failure')
        args_list = tuple(arguments_dict.values())
        obj.__parse_method__ = parse_method
        return self._executor(args_list, phase)

    def run(self, obj, *args, phase='predict'):
        """
        Run the specific graph.

        Args:
            phase (str): The phase name. Default: 'predict'.

        Returns:
            Tensor/Tuple, return execute result.

        Raises:
            KeyError: If the requested phase has not been compiled.
        """
        if phase == 'save':
            return self._executor((), phase + '.' + str(obj.create_time))

        phase_real = self.phase_prefix + phase + '.' + str(obj.create_time)
        if self.has_compiled(phase_real):
            return self._exec_pip(obj, *args, phase=phase_real)
        raise KeyError('{} graph is not exist.'.format(phase_real))

    def del_net_res(self, net_id):
        """Release backend resources associated with a network id."""
        self._executor.del_net_res(net_id)

    def _get_func_graph_proto(self, exec_id, ir_type="onnx_ir", use_prefix=False):
        """Get graph proto from pipeline."""
        if use_prefix:
            exec_id = self.phase_prefix + exec_id
        if self._executor.has_compiled(exec_id) is False:
            return None
        return self._executor.get_func_graph_proto(exec_id, ir_type)

    def export(self, net, file_name, file_format='GEIR'):
        """
        Export graph.

        Args:
            net (Cell): MindSpore network
            file_name (str): File name of model to export
            file_format (str): MindSpore currently support 'GEIR' and 'ONNX' format for exported model
        """
        from .._c_expression import export_graph
        # Phase layout matches the 'export' branch built in compile().
        phase = 'export' + '.' + self.phase_prefix + '.' + str(net.create_time)
        export_graph(file_name, file_format, phase)

    def fetch_info_for_quant_export(self, exec_id):
        """Get graph proto from pipeline."""
        if self._executor.has_compiled(exec_id) is False:
            return None
        return self._executor.fetch_info_for_quant_export(exec_id)
# Module-level executor singletons shared across the package.
_executor = _Executor()
_pynative_exec = _PynativeExecutor()

# Only ms_function is part of the public API.
__all__ = ['ms_function']