
base.py 34 kB

# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Basic composite operations."""
from functools import partial
from types import FunctionType

from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
from ..operations import _grad_ops
from .. import operations as P
from .. import signature as sig

__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]

def add_flags(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    Note:
        Only supports bool values.

    Args:
        fn (Function): Function or cell to add the flag to. Default: None.
        flags (dict): Flags passed as keyword arguments. Default: None.

    Returns:
        Function, the function with added flags.

    Examples:
        >>> net = Net()
        >>> net = add_flags(net, predict=True)
        >>> print(hasattr(net, '_mindspore_flags'))
        True
    """
    def deco(fn):
        # need to set the attr here and access it on the C++ side
        if not hasattr(fn, "_mindspore_flags"):
            fn._mindspore_flags = {}
        fn._mindspore_flags.update({**flags})
        return fn
    ret = deco
    if fn is not None:
        ret = deco(fn)
    return ret
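
# Because `fn` defaults to None, `add_flags` also works as a decorator factory;
# a minimal sketch (the flag name `predict` is illustrative, as in the docstring):
#
#     @add_flags(predict=True)
#     def infer(x):
#         return x
#
# `infer._mindspore_flags` is then `{'predict': True}`.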

def core(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    By default, the `core` flag is set to True, so this decorator can be used to
    mark a function as a core graph.

    Args:
        fn (Function): Function to add the flag to. Default: None.
        flags (dict): Flags passed as keyword arguments; `core` is always set to True to indicate
            that this is a core function, and other flags can be added. Default: None.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> net = core(net, predict=True)
        >>> print(hasattr(net, '_mindspore_flags'))
        True
    """
    # need to set the attr here and access it on the C++ side
    def deco(fn):
        fn._mindspore_flags = {
            'core': True,
            **flags,
        }
        return fn
    if fn is not None:
        ret = deco(fn)
    else:
        ret = deco
    return ret

class GradOperation(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function for the input function.

    The gradient function generated by the `GradOperation` higher-order function can be customized by
    construction arguments.

    Given an input function `net = Net()` that takes `x` and `y` as inputs, and has a parameter `z`,
    see `Net` in Examples.

    To generate a gradient function that returns gradients with respect to the first input
    (see `GradNetWrtX` in Examples):

    1. Construct a `GradOperation` higher-order function with default arguments:
       `grad_op = GradOperation()`.

    2. Call it with the input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    3. Call the gradient function with the input function's inputs to get the gradients
       with respect to the first input: `grad_op(net)(x, y)`.

    To generate a gradient function that returns gradients with respect to all inputs
    (see `GradNetWrtXY` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_all=True`, which
       indicates getting gradients with respect to all inputs, namely `x` and `y` in
       example function `Net()`: `grad_op = GradOperation(get_all=True)`.

    2. Call it with the input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    3. Call the gradient function with the input function's inputs to get the gradients
       with respect to all inputs: `gradient_function(x, y)`.

    To generate a gradient function that returns gradients with respect to given parameters
    (see `GradNetWithWrtParams` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
       `grad_op = GradOperation(get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed to the input function when constructing the
       `GradOperation` higher-order function; it will be used as a parameter filter that determines
       which gradients to return: `params = ParameterTuple(net.trainable_params())`.

    3. Call it with the input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with the input function's inputs to get the gradients with
       respect to the given parameters: `gradient_function(x, y)`.

    To generate a gradient function that returns gradients with respect to all inputs and given parameters
    in the format of ((dx, dy), (dz)) (see `GradNetWrtInputsAndParams` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
       `grad_op = GradOperation(get_all=True, get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed along with the input function when constructing the
       `GradOperation` higher-order function: `params = ParameterTuple(net.trainable_params())`.

    3. Call it with the input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with the input function's inputs to get the gradients with
       respect to all inputs and the given parameters: `gradient_function(x, y)`.

    We can configure the sensitivity (gradient with respect to output) by setting `sens_param` to True
    and passing an extra sensitivity input to the gradient function. The sensitivity input should have
    the same shape and type as the input function's output (see `GradNetWrtXYWithSensParam` in Examples).

    1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
       `grad_op = GradOperation(get_all=True, sens_param=True)`.

    2. Define `grad_wrt_output` as `sens_param`, which works as the gradient with respect to output:
       `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.

    3. Call it with the input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    4. Call the gradient function with the input function's inputs and `sens_param` to
       get the gradients with respect to all inputs:
       `gradient_function(x, y, grad_wrt_output)`.

    Args:
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If `get_all` and `get_by_list` are both False, get the gradient with respect to the first input.
            If `get_all` and `get_by_list` are both True, get the gradients with respect to inputs and
            Parameter variables at the same time, in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If `sens_param` is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            If `sens_param` is True, the sensitivity (gradient with respect to output) needs to be passed
            as a positional argument or a keyword argument; if it is passed as a keyword argument,
            the key must be 'sens'. Default: False.

    Returns:
        The higher-order function which takes a function as argument and returns the gradient function for it.

    Raises:
        TypeError: If `get_all`, `get_by_list` or `sens_param` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ParameterTuple
        >>> from mindspore.ops.composite import GradOperation
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = P.MatMul()
        ...         self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
        ...     def construct(self, x, y):
        ...         x = x * self.z
        ...         out = self.matmul(x, y)
        ...         return out
        ...
        >>> class GradNetWrtX(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtX, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation()
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        ...
        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        >>> output = GradNetWrtX(Net())(x, y)
        >>> print(output)
        [[1.4100001 1.5999999 6.6      ]
         [1.4100001 1.5999999 6.6      ]]
        >>>
        >>> class GradNetWrtXY(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtXY, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation(get_all=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.1, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtXY(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 4.50000000e+00,  2.70000005e+00,  3.60000014e+00],
         [ 4.50000000e+00,  2.70000005e+00,  3.60000014e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 2.59999990e+00,  2.59999990e+00,  2.59999990e+00],
         [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
         [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]]))
        >>>
        >>> class GradNetWrtXYWithSensParam(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtXYWithSensParam, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation(get_all=True, sens_param=True)
        ...         self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y, self.grad_wrt_output)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtXYWithSensParam(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 2.21099997e+00,  5.09999990e-01,  1.49000001e+00],
         [ 5.58800030e+00,  2.68000007e+00,  4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 1.51999998e+00,  2.81999993e+00,  2.14000010e+00],
         [ 1.09999990e+00,  2.04999995e+00,  1.54999995e+00],
         [ 9.00000036e-01,  1.54999995e+00,  1.25000000e+00]]))
        >>>
        >>> class GradNetWithWrtParams(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWithWrtParams, self).__init__()
        ...         self.net = net
        ...         self.params = ParameterTuple(net.trainable_params())
        ...         self.grad_op = GradOperation(get_by_list=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net, self.params)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWithWrtParams(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)
        >>>
        >>> class GradNetWrtInputsAndParams(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtInputsAndParams, self).__init__()
        ...         self.net = net
        ...         self.params = ParameterTuple(net.trainable_params())
        ...         self.grad_op = GradOperation(get_all=True, get_by_list=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net, self.params)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtInputsAndParams(Net())(x, y)
        >>> print(output)
        ((Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00],
         [ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 6.00000024e-01,  6.00000024e-01,  6.00000024e-01],
         [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
         [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]])), (Tensor(shape=[1], dtype=Float32, value=
        [ 1.29020004e+01]),))
    """
    def __init__(self, get_all=False, get_by_list=False, sens_param=False):
        """Initialize GradOperation."""
        if not isinstance(get_all, bool):
            raise TypeError(f"For 'GradOperation', the 'get_all' should be bool, but got {type(get_all).__name__}")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For 'GradOperation', the 'get_by_list' should be bool, "
                            f"but got {type(get_by_list).__name__}")
        if not isinstance(sens_param, bool):
            raise TypeError(f"For 'GradOperation', the 'sens_param' should be bool, "
                            f"but got {type(sens_param).__name__}")
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False)
        self.grad_fn = None
        self.fn = None
        self.pynative_ = False

    def _pynative_forward_run(self, grad, args, kwargs, fn):
        """Pynative forward run to build the grad graph."""
        new_kwargs = kwargs
        if self.sens_param:
            if 'sens' not in kwargs.keys():
                args = args[:-1]
            else:
                new_kwargs = kwargs.copy()
                new_kwargs.pop('sens')
        if isinstance(fn, FunctionType):
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                _pynative_executor.set_grad_flag(True)
                _pynative_executor.new_graph(fn, *args, **new_kwargs)
                output = fn(*args, **new_kwargs)
                _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
        else:
            # Check if fn has already run
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                fn.set_grad()
                fn(*args, **new_kwargs)
                fn.set_grad(False)

    def __call__(self, fn, weights=None):
        if self.grad_fn is not None and self.fn == fn:
            return self.grad_fn
        grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
        # If calling Grad in GRAPH_MODE or from inside ms_function, do grad in GRAPH_MODE.
        # If calling Grad in pure PYNATIVE_MODE, do grad in PYNATIVE_MODE.
        # In pure PYNATIVE_MODE the outer-layer after_grad is only used to set the pynative
        # flag for the inner GradOperation.
        # In PYNATIVE_MODE, when Grad is called from ms_function, the outer-layer after_grad
        # does grad in GRAPH_MODE.
        if context.get_context("mode") == context.GRAPH_MODE:
            if self.get_by_list:
                @ms_function
                def after_grad(*args):
                    return grad_(fn, weights)(*args)
            else:
                @ms_function
                def after_grad(*args):
                    return grad_(fn)(*args)
        elif self.pynative_:
            @_wrap_func
            def after_grad(*args, **kwargs):
                if _pynative_executor.check_graph(fn, *args, **kwargs):
                    print("Another grad step is running")
                self._pynative_forward_run(grad_, args, kwargs, fn)
                _pynative_executor.grad(grad_, fn, weights, (0,), *args, **kwargs)
                out = _pynative_executor(fn, *args, **kwargs)
                _pynative_executor.clear_grad(fn, *args, **kwargs)
                return out
        else:
            grad_.pynative_ = True
            # after_grad in this branch can't use @ms_function; just call grad_ directly
            if self.get_by_list:
                def after_grad(*args, **kwargs):
                    return grad_(fn, weights)(*args, **kwargs)
            else:
                def after_grad(*args, **kwargs):
                    return grad_(fn)(*args, **kwargs)
        self.grad_fn = after_grad
        self.fn = fn
        return self.grad_fn
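
# `__call__` dispatches on the execution mode: graph mode wraps `after_grad` in
# `ms_function`, while pynative mode runs the forward pass first to record the graph.
# `GradOperation` also accepts a plain Python function, not only a Cell; a minimal
# hedged sketch (`mul`, `x`, `y` are illustrative; `x` and `y` are Tensors):
#
#     def mul(x, y):
#         return x * y
#
#     grad_all = GradOperation(get_all=True)
#     dx, dy = grad_all(mul)(x, y)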

class _Grad(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function by position for the input function.
    """
    def __init__(self, get_by_list=False, sens_param=False, get_by_position=False):
        """Initialize _Grad."""
        if not isinstance(get_by_position, bool):
            raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
                            f"but got {type(get_by_position).__name__}")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For '_Grad', the 'get_by_list' should be bool, "
                            f"but got {type(get_by_list).__name__}")
        if not isinstance(sens_param, bool):
            raise TypeError(f"For '_Grad', the 'sens_param' should be bool, "
                            f"but got {type(sens_param).__name__}")
        self.get_by_position = get_by_position
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        GradOperation_.__init__(self, 'grad', False, get_by_list, sens_param, get_by_position)
        self.grad_fn = None
        self.fn = None
        self.pynative_ = False
        self.grad_position = None

    def _pynative_forward_run(self, grad, args, kwargs, fn):
        """Pynative forward run to build the grad graph."""
        new_kwargs = kwargs
        if self.sens_param:
            if 'sens' not in kwargs.keys():
                args = args[:-1]
            else:
                new_kwargs = kwargs.copy()
                new_kwargs.pop('sens')
        if isinstance(fn, FunctionType):
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                _pynative_executor.set_grad_flag(True)
                _pynative_executor.new_graph(fn, *args, **new_kwargs)
                output = fn(*args, **new_kwargs)
                _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
        else:
            # Check if fn has already run
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                fn.set_grad()
                fn(*args, **new_kwargs)
                fn.set_grad(False)

    def __call__(self, fn, weights=None, grad_position=0):
        if self.grad_fn is not None and self.fn == fn and self.grad_position == grad_position:
            return self.grad_fn
        grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position)
        # If calling Grad in GRAPH_MODE or from inside ms_function, do grad in GRAPH_MODE.
        # If calling Grad in pure PYNATIVE_MODE, do grad in PYNATIVE_MODE.
        # In pure PYNATIVE_MODE the outer-layer after_grad is only used to set the pynative
        # flag for the inner GradOperation.
        # In PYNATIVE_MODE, when Grad is called from ms_function, the outer-layer after_grad
        # does grad in GRAPH_MODE.
        if context.get_context("mode") == context.GRAPH_MODE:
            if self.get_by_position:
                @ms_function
                def after_grad(*args):
                    return grad_(fn, weights, grad_position)(*args)
            else:
                if self.get_by_list:
                    @ms_function
                    def after_grad(*args):
                        return grad_(fn, weights)(*args)
                else:
                    @ms_function
                    def after_grad(*args):
                        return grad_(fn)(*args)
        elif self.pynative_:
            @_wrap_func
            def after_grad(*args, **kwargs):
                if _pynative_executor.check_graph(fn, *args, **kwargs):
                    print("Another grad step is running")
                self._pynative_forward_run(grad_, args, kwargs, fn)
                _pynative_executor.grad(grad_, fn, weights, grad_position, *args, **kwargs)
                out = _pynative_executor(fn, *args, **kwargs)
                _pynative_executor.clear_grad(fn, *args, **kwargs)
                return out
        else:
            grad_.pynative_ = True
            # after_grad in this branch can't use @ms_function; just call grad_ directly
            if self.get_by_position:
                def after_grad(*args, **kwargs):
                    return grad_(fn, weights, grad_position)(*args, **kwargs)
            else:
                if self.get_by_list:
                    def after_grad(*args, **kwargs):
                        return grad_(fn, weights)(*args, **kwargs)
                else:
                    def after_grad(*args, **kwargs):
                        return grad_(fn)(*args, **kwargs)
        self.grad_fn = after_grad
        self.fn = fn
        self.grad_position = grad_position
        return self.grad_fn
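
# `_Grad` extends `GradOperation_` with position-based gradient selection. A minimal
# hedged sketch of direct use (`fn`, `x`, `y` are illustrative; the tuple form of
# `grad_position` mirrors the `(0,)` that `GradOperation` passes above):
#
#     grad_wrt_second_input = _Grad(get_by_position=True)
#     dy = grad_wrt_second_input(fn, None, grad_position=(1,))(x, y)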

class MultitypeFuncGraph(MultitypeFuncGraph_):
    """
    Generates overloaded functions.

    MultitypeFuncGraph is a class used to generate overloaded functions, considering different types as inputs.
    Initialize a `MultitypeFuncGraph` object with a name, and use `register` with input types as the decorator
    for the function to be registered. The object can then be called with different types of inputs,
    and works with `HyperMap` and `Map`.

    Args:
        name (str): Operator name.
        read_value (bool): If the registered function does not need to set values on Parameters,
            and all inputs will be passed by value, set `read_value` to True. Default: False.

    Raises:
        ValueError: If failed to find a matching function for the given arguments.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # `add` is a metagraph object which will add two objects according to
        >>> # input type using the ".register" decorator.
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore import dtype as mstype
        >>>
        >>> tensor_add = ops.Add()
        >>> add = MultitypeFuncGraph('add')
        >>> @add.register("Number", "Number")
        ... def add_scalar(x, y):
        ...     return x + y
        >>> @add.register("Tensor", "Tensor")
        ... def add_tensor(x, y):
        ...     return tensor_add(x, y)
        >>> output = add(1, 2)
        >>> print(output)
        3
        >>> output = add(Tensor([0.1, 0.6, 1.2], dtype=mstype.float32), Tensor([0.1, 0.6, 1.2], dtype=mstype.float32))
        >>> print(output)
        [0.2 1.2 2.4]
    """
    def __init__(self, name, read_value=False):
        """Initialize MultitypeFuncGraph."""
        MultitypeFuncGraph_.__init__(self, name)
        self.entries = list()
        if read_value:
            self.set_signatures((
                sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))

    def __call__(self, *args):
        if len(self.entries) == 1:
            output = self.entries[0][1](*args)
            return output
        types = tuple(map(mstype.get_py_obj_dtype, args))
        for sigs, fn in self.entries:
            if len(sigs) != len(types):
                continue
            if any(not mstype.issubclass_(type_, sig) for sig, type_ in zip(sigs, types)):
                continue
            output = fn(*args)
            return output
        raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given args. Got (sigs, fn): {self.entries}, "
                         f"and (dtype, args): {types}.")

    def register(self, *type_names):
        """
        Register a function for the given type strings.

        Args:
            type_names (Union[str, :class:`mindspore.dtype`]): Input type names or types list.

        Return:
            decorator, a decorator to register the function to run when called under the
            types described in `type_names`.
        """
        def deco(fn):
            def convert_type(type_input):
                if isinstance(type_input, str):
                    return mstype.typing.str_to_type(type_input)
                if not isinstance(type_input, mstype.Type):
                    raise TypeError(f"For 'MultitypeFuncGraph', register only supports str or {mstype.Type}, but got "
                                    f"'type_input': {type_input}.")
                return type_input
            types = tuple(map(convert_type, type_names))
            self.register_fn(type_names, fn)
            self.entries.append((types, fn))
            return fn
        return deco
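
# `convert_type` above passes `mstype.Type` instances through unchanged, so `register`
# accepts dtype objects as well as type-name strings; a hedged sketch mirroring the
# `add` example in the class docstring:
#
#     @add.register(mstype.int32, mstype.int32)
#     def _add_int32(x, y):
#         return x + y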

class HyperMap(HyperMap_):
    """
    HyperMap will apply the set operation to input sequences.

    Apply the operations to every element of the sequence or nested sequence. Different
    from `Map`, `HyperMap` supports applying operations on nested structures.

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operation should be put in the first input of the instance. Default: None.
        reverse (bool): The optimizer needs to be inverted in some scenarios to improve parallel
            performance; general users can ignore it. `reverse` is the flag that decides whether to
            apply the operation reversely. Only supported in graph mode. Default: False.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the
          same length, and each row of the sequences will be the inputs of the operation.
          If `ops` is `None`, the first input is the operation, and the others are inputs.

    Outputs:
        Sequence or nested sequence, the sequence of outputs after applying the function,
        e.g. `operation(args[0][i], args[1][i])`.

    Raises:
        TypeError: If `ops` is neither MultitypeFuncGraph nor None.
        TypeError: If `args` is not a Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import dtype as mstype
        >>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
        ...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
        >>> # square every tensor in the nested list
        >>>
        >>> square = MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = HyperMap()
        >>> output = common_map(square, nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
        >>> square_map = HyperMap(square, False)
        >>> output = square_map(nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
    """
    def __init__(self, ops=None, reverse=False):
        """Initialize HyperMap."""
        self.ops = ops
        if ops:
            HyperMap_.__init__(self, reverse, ops)
        else:
            HyperMap_.__init__(self, reverse)

    def __call__(self, *args):
        func = self.ops
        args_list = args
        hypermap = self
        if self.ops is None:
            func = args[0]
            args_list = args[1:]
            hypermap = partial(self, func)
        # is leaf
        if not isinstance(args_list[0], (tuple, list)):
            return func(*args_list)
        return tuple(map(hypermap, *args_list))
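
# The Python `__call__` above recurses structurally: a non-sequence argument is a
# leaf and goes straight to `func`, while tuples and lists are mapped element-wise
# via `partial(self, func)`. A sketch of the resulting call shape for a two-level
# nest (names illustrative):
#
#     common_map = HyperMap()
#     # common_map(square, ((t1, t2), (t3, t4)))
#     # -> ((square(t1), square(t2)), (square(t3), square(t4)))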

class Map(Map_):
    """
    Map will apply the set operation on input sequences.

    Apply the operations to every element of the sequence.

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operation should be put in the first input of the instance. Default: None.
        reverse (bool): The optimizer needs to be inverted in some scenarios to improve parallel
            performance; general users can ignore it. `reverse` is the flag that decides whether to
            apply the operation reversely. Only supported in graph mode. Default: False.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences of the
          same length, and each row of the sequences is one set of inputs to the operation. For example,
          if the length of `args` is 2, then for each `i` within the length of each sequence,
          `(args[0][i], args[1][i])` will be the input of the operation.
          If `ops` is `None`, the first input is the operation, and the others are inputs.

    Outputs:
        Sequence, the sequence of outputs after applying the function, e.g. `operation(args[0][i], args[1][i])`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import dtype as mstype
        >>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32))
        >>> # square every tensor in the list
        >>>
        >>> square = MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = Map()
        >>> output = common_map(square, tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
        >>> square_map = Map(square, False)
        >>> output = square_map(tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
    """
    def __init__(self, ops=None, reverse=False):
        """Initialize Map."""
        self.ops = ops
        if ops:
            Map_.__init__(self, reverse, ops)
        else:
            Map_.__init__(self, reverse)

    def __call__(self, *args):
        func = self.ops
        args_list = args
        if self.ops is None:
            func = args[0]
            args_list = args[1:]
        return tuple(map(func, *args_list))
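
# Unlike `HyperMap`, this fallback does not recurse into nested sequences; it is a
# single `map` across the argument sequences (names illustrative):
#
#     # Map()(square, (t1, t2, t3)) -> (square(t1), square(t2), square(t3))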

class _ListAppend(ListAppend_):
    """
    A metafuncgraph class that appends one element to a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """
    def __init__(self, name):
        """Initialize _ListAppend."""
        ListAppend_.__init__(self, name)

    def __call__(self, *args):
        pass


_append = _ListAppend("append")


class _Tail(Tail_):
    """
    A metafuncgraph class that generates the tail elements of a tuple.

    Args:
        name (str): The name of the metafuncgraph object.
    """
    def __init__(self, name):
        """Initialize _Tail."""
        Tail_.__init__(self, name)

    def __call__(self, *args):
        pass


tail = _Tail('tail')


class _ZipOperation(ZipOperation_):
    """Generates a tuple of zip iterations for inputs."""

    def __init__(self, name):
        """Initialize _ZipOperation."""
        ZipOperation_.__init__(self, name)

    def __call__(self, *args):
        pass


zip_operation = _ZipOperation('zip_operation')
"""`zip_operation` will generate a tuple of zip iterations of inputs."""


env_get = MultitypeFuncGraph("env_get")
env_getitem = Primitive('env_getitem')
ref_to_embed = _grad_ops.RefToEmbed()
zeros_like = P.ZerosLike()


@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
    """Used to get env."""
    return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter))
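
# A note on the lookup above: `env_getitem(env, key, default)` reads the entry for
# the embedded parameter key from the environment and falls back to the default,
# `zeros_like(parameter)`, so a parameter with no recorded entry presumably
# contributes a zero gradient.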