You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

functional.py 17 kB

4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2021 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """The names of functional part are summarized here."""
  18. from mindspore.common._register_for_tensor import tensor_operator_registry
  19. from mindspore.common import ms_function
  20. from mindspore.common import Tensor
  21. from mindspore.nn.grad.cell_grad import _JvpInner
  22. from mindspore.nn.grad.cell_grad import _VjpInner
  23. from mindspore.ops import _constants
  24. from mindspore.ops.primitive import constexpr
  25. from .primitive import Primitive
  26. from . import operations as P
  27. from .operations import _grad_ops
  28. from .composite import _Grad
  29. from .._c_expression import security
  30. typeof = Primitive('typeof')
  31. hastype = Primitive('hastype')
  32. cast = P.Cast()
  33. dtype = P.DType()
  34. isconstant = Primitive('is_constant')
  35. isconstant.set_const_prim(True)
  36. issubclass_ = P.IsSubClass()
  37. isinstance_ = P.IsInstance()
  38. eye = P.Eye()
  39. fill = P.Fill()
  40. tile = P.Tile()
  41. select = P.Select()
  42. size = P.Size()
  43. ones_like = P.OnesLike()
  44. shape = P.Shape()
  45. dyn_shape = P.DynamicShape()
  46. rank = P.Rank()
  47. reshape = P.Reshape()
  48. merge = P.Merge()
  49. geswitch = P.GeSwitch()
  50. addn = P.AddN()
  51. absolute = P.Abs()
  52. tensor_add = P.Add()
  53. add = tensor_add
  54. neg_tensor = P.Neg()
  55. tensor_lt = P.Less()
  56. less = tensor_lt
  57. tensor_le = P.LessEqual()
  58. le = tensor_le
  59. tensor_gt = P.Greater()
  60. gt = tensor_gt
  61. tensor_ge = P.GreaterEqual()
  62. ge = tensor_ge
  63. tensor_sub = P.Sub()
  64. sub = tensor_sub
  65. tensor_mul = P.Mul()
  66. mul = tensor_mul
  67. tensor_div = P.RealDiv()
  68. div = tensor_div
  69. tensor_floordiv = P.FloorDiv()
  70. floordiv = tensor_floordiv
  71. tensor_pow = P.Pow()
  72. pows = tensor_pow
  73. tensor_mod = P.FloorMod()
  74. floormod = tensor_mod
  75. tensor_exp = P.Exp()
  76. exp = tensor_exp
  77. tensor_expm1 = P.Expm1()
  78. tensor_slice = P.Slice()
  79. strided_slice = P.StridedSlice()
  80. same_type_shape = P.SameTypeShape()
  81. check_bprop = P.CheckBprop()
  82. equal = P.Equal()
  83. not_equal = P.NotEqual()
  84. isfinite = P.IsFinite()
  85. isnan = P.IsNan()
  86. assign_sub = P.AssignSub()
  87. assign_add = P.AssignAdd()
  88. assign = P.Assign()
  89. square = P.Square()
  90. sqrt = P.Sqrt()
  91. log = P.Log()
  92. reduce_sum = P.ReduceSum()
  93. reduce_max = P.ReduceMax()
  94. reduce_min = P.ReduceMin()
  95. reduce_mean = P.ReduceMean()
  96. reduce_prod = P.ReduceProd()
  97. tensor_slice = P.Slice()
  98. maximum = P.Maximum()
  99. minimum = P.Minimum()
  100. floor = P.Floor()
  101. logical_not = P.LogicalNot()
  102. logical_or = P.LogicalOr()
  103. logical_and = P.LogicalAnd()
  104. sin = P.Sin()
  105. cos = P.Cos()
  106. tan = P.Tan()
  107. asin = P.Asin()
  108. acos = P.ACos()
  109. atan = P.Atan()
  110. sinh = P.Sinh()
  111. cosh = P.Cosh()
  112. tanh = P.Tanh()
  113. asinh = P.Asinh()
  114. acosh = P.Acosh()
  115. atanh = P.Atanh()
  116. atan2 = P.Atan2()
  117. bitwise_and = P.BitwiseAnd()
  118. bitwise_or = P.BitwiseOr()
  119. bitwise_xor = P.BitwiseXor()
  120. invert = P.Invert()
  121. erf = P.Erf()
  122. erfc = P.Erfc()
  123. sort = P.Sort()
  124. tensor_range = P.Range()
  125. scalar_to_array = P.ScalarToArray()
  126. scalar_to_tensor = P.ScalarToTensor()
  127. tuple_to_array = P.TupleToArray()
  128. scalar_cast = P.ScalarCast()
  129. if not security.enable_security():
  130. print_ = P.Print()
  131. expand_dims = P.ExpandDims()
  132. transpose = P.Transpose()
  133. squeeze = P.Squeeze()
  134. scatter_nd = P.ScatterNd()
  135. gather = P.Gather()
  136. gather_d = P.GatherD()
  137. gather_nd = P.GatherNd()
  138. scatter_update = P.ScatterUpdate()
  139. tensor_scatter_update = P.TensorScatterUpdate()
  140. scatter_nd_update = P.ScatterNdUpdate()
  141. stack = P.Stack()
  142. def pack(x):
  143. """Call stack in this pack function."""
  144. print("WARNING: 'pack' is deprecated from version 1.1 and will be removed in a future version, use 'stack' instead"
  145. ".")
  146. return stack(x)
# partial: bind some arguments of an operator/function ahead of time.
partial = P.Partial()
# depend: mount a node to another node
depend = P.Depend()
# identity: returns its input unchanged (used to anchor values in the graph).
identity = P.identity()
  151. @constexpr
  152. def _convert_grad_position_type(grad_position):
  153. """Check and convert the type and size of grad position index."""
  154. if isinstance(grad_position, tuple):
  155. for gp in grad_position:
  156. if not isinstance(gp, int):
  157. raise TypeError(f"For 'F.grad', the element in 'grad_position' should be int, "
  158. f"but got {type(gp).__name__}")
  159. if gp < 0:
  160. raise ValueError("The element in grad_position must be >= 0.")
  161. elif isinstance(grad_position, int):
  162. if grad_position < 0:
  163. raise ValueError("grad_position must be >= 0.")
  164. grad_position = (grad_position,)
  165. else:
  166. raise TypeError(f"For 'F.grad', the 'grad_position' should be int or tuple, "
  167. f"but got {type(grad_position).__name__}")
  168. return grad_position
# Pre-built gradient operators that differentiate w.r.t. selected input
# positions, without and with an explicit sensitivity (output-gradient) input.
grad_by_position = _Grad(get_by_list=False, sens_param=False, get_by_position=True)
grad_by_position_with_sens = _Grad(get_by_list=False, sens_param=True, get_by_position=True)
  171. def grad(fn, grad_position=0, sens_param=False):
  172. r"""
  173. A wrapper function to generate the gradient function for the input function.
  174. Args:
  175. fn (Union(Cell, function)): Function to do GradOperation.
  176. grad_position (Union(int, tuple[int])): If int, get the gradient with respect to single input.
  177. If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: 0.
  178. sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
  179. If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
  180. Returns:
  181. Function, returns the gradient function for the input function or cell.
  182. Supported Platforms:
  183. ``Ascend`` ``GPU`` ``CPU``
  184. Examples:
  185. >>> import numpy as np
  186. >>> import mindspore.nn as nn
  187. >>> import mindspore.context as context
  188. >>> from mindspore import Tensor
  189. >>> from mindspore.ops.functional import grad
  190. >>> context.set_context(mode=context.GRAPH_MODE)
  191. >>> class Net(nn.Cell):
  192. ... def construct(self, x, y, z):
  193. ... return x*y*z
  194. >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  195. >>> y = Tensor(np.array([[-2, 3], [-1, 2]]).astype(np.float32))
  196. >>> z = Tensor(np.array([[0, 3], [5, -1]]).astype(np.float32))
  197. >>> net = Net()
  198. >>> output = grad(net, grad_position=(1, 2))(x, y, z)
  199. >>> print(output)
  200. (Tensor(shape=[2, 2], dtype=Float32, value=
  201. [[ 0.00000000e+00, 6.00000000e+00],
  202. [ 1.50000000e+01, -4.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
  203. [[-2.00000000e+00, 6.00000000e+00],
  204. [-3.00000000e+00, 8.00000000e+00]]))
  205. """
  206. grad_position = _convert_grad_position_type(grad_position)
  207. if sens_param:
  208. return grad_by_position_with_sens(fn, None, grad_position)
  209. return grad_by_position(fn, None, grad_position)
  210. def jvp(fn, inputs, v):
  211. """
  212. Compute the jacobian-vector-product of the given network.
  213. Args:
  214. fn (Function or Cell): The function or net that takes Tensor inputs and returns a tensor or tuple of Tensors.
  215. inputs (Tensor or tuple or list): The inputs to `fn`.
  216. v (Tensor or tuple or list): The shape and type of v should be the same as inputs.
  217. Returns:
  218. Tuple, tuple of output and jvp.
  219. - netout(Tensors or Tuple of Tensors), the output of "fn(inputs)".
  220. - jvp(Tensors or Tuple of Tensors), the result of the dot product.
  221. Raises:
  222. TypeError: If the input is not a tensor or tuple or list of tensors.
  223. Supported Platforms:
  224. ``Ascend`` ``GPU`` ``CPU``
  225. Examples:
  226. >>> from mindspore.ops import functional as F
  227. >>> from mindspore import Tensor
  228. >>> class Net(nn.Cell):
  229. ... def construct(self, x, y):
  230. ... return x**3 + y
  231. >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  232. >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  233. >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
  234. >>> output = F.jvp(Net(), (x, y), (v, v))
  235. >>> print(output[0])
  236. [[ 2. 10.]
  237. [30. 68.]]
  238. >>> print(output[1])
  239. [[ 4. 13.]
  240. [28. 49.]]
  241. """
  242. jvp_inner = _JvpInner()
  243. @ms_function
  244. def _wrap_container(*arg):
  245. args = arg[1:]
  246. vectors = arg[0]
  247. return jvp_inner(fn, vectors, *args)
  248. if not isinstance(inputs, (Tensor, tuple, list)):
  249. _raise_type_error()
  250. if isinstance(inputs, (tuple, list)):
  251. return _wrap_container(v, *inputs)
  252. return _wrap_container(v, inputs)
  253. def vjp(fn, inputs, v):
  254. """
  255. Compute the vector-jacobian-product of the given network.
  256. Args:
  257. fn (Function or Cell): The function or net that takes Tensor inputs and returns a tensor or tuple of Tensors.
  258. inputs (Tensor or tuple or list): The inputs to `fn`.
  259. v (Tensor or tuple or list): The shape and type of v should be the same as outputs.
  260. Returns:
  261. Tuple, tuple of output and jvp.
  262. - netout(Tensors or Tuple of Tensors), the output of "fn(inputs)".
  263. - vjp(Tensors or Tuple of Tensors), the result of the dot product.
  264. Raises:
  265. TypeError: If the input is not a tensor or tuple or list of tensors.
  266. Supported Platforms:
  267. ``Ascend`` ``GPU`` ``CPU``
  268. Examples:
  269. >>> from mindspore.ops import functional as F
  270. >>> from mindspore import Tensor
  271. >>> class Net(nn.Cell):
  272. ... def construct(self, x, y):
  273. ... return x**3 + y
  274. >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  275. >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
  276. >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
  277. >>> output = F.vjp(Net(), (x, y), v)
  278. >>> print(output[0])
  279. [[ 2. 10.]
  280. [30. 68.]]
  281. >>> print(output[1])
  282. (Tensor(shape=[2, 2], dtype=Float32, value=
  283. [[ 3.00000000e+00, 1.20000000e+01],
  284. [ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
  285. [[ 1.00000000e+00, 1.00000000e+00],
  286. [ 1.00000000e+00, 1.00000000e+00]]))
  287. """
  288. vjp_inner = _VjpInner()
  289. @ms_function
  290. def wrap_container(*arg):
  291. args = arg[:-1]
  292. vectors = arg[-1]
  293. return vjp_inner(fn, *args, vectors)
  294. if not isinstance(inputs, (Tensor, tuple, list)):
  295. _raise_type_error()
  296. if isinstance(inputs, (tuple, list)):
  297. return wrap_container(*inputs, v)
  298. return wrap_container(inputs, v)
@constexpr
def _raise_type_error():
    """Raise the shared TypeError used by jvp/vjp when `inputs` has a bad type."""
    raise TypeError("The inputs type should be a Tensor, tuple or list of Tensor.")
  302. tuple_setitem = Primitive('tuple_setitem')
  303. tuple_getitem = Primitive(_constants.kTupleGetItem)
  304. list_getitem = Primitive('list_getitem')
  305. list_setitem = Primitive('list_setitem')
  306. dict_getitem = Primitive('dict_getitem')
  307. dict_setitem = Primitive('dict_setitem')
  308. tuple_div = Primitive("tuple_div")
  309. tuple_len = Primitive("tuple_len")
  310. list_len = Primitive("list_len")
  311. tuple_reversed = Primitive("tuple_reversed")
  312. make_range = Primitive("make_range")
  313. make_tuple = Primitive('MakeTuple')
  314. make_dict = Primitive('make_dict')
  315. make_list = Primitive('make_list')
  316. make_slice = Primitive('make_slice')
  317. tuple_equal = Primitive("tuple_equal")
  318. list_equal = Primitive("list_equal")
  319. make_ref = Primitive("make_ref")
  320. scalar_add = Primitive(_constants.kScalarAdd)
  321. scalar_mul = Primitive(_constants.kScalarMul)
  322. scalar_sub = Primitive(_constants.kScalarSub)
  323. scalar_div = Primitive(_constants.kScalarDiv)
  324. scalar_floordiv = Primitive(_constants.kScalarFloordiv)
  325. scalar_log = Primitive('scalar_log')
  326. scalar_pow = Primitive(_constants.kScalarPow)
  327. scalar_gt = Primitive('scalar_gt')
  328. scalar_ge = Primitive('scalar_ge')
  329. scalar_le = Primitive('scalar_le')
  330. scalar_lt = Primitive('scalar_lt')
  331. scalar_eq = Primitive('scalar_eq')
  332. scalar_ne = Primitive('scalar_ne')
  333. scalar_uadd = Primitive(_constants.kScalarUadd)
  334. scalar_usub = Primitive(_constants.kScalarUsub)
  335. scalar_mod = Primitive(_constants.kScalarMod)
  336. string_eq = Primitive('string_equal')
  337. string_concat = Primitive('string_concat')
  338. bool_not = Primitive("bool_not")
  339. bool_or = Primitive("bool_or")
  340. bool_and = Primitive("bool_and")
  341. bool_eq = Primitive("bool_eq")
  342. logical_and = P.LogicalAnd()
  343. logical_or = P.LogicalOr()
  344. logical_not = P.LogicalNot()
  345. cumsum = P.CumSum()
  346. cumprod = P.CumProd()
  347. tensor_scatter_add = P.TensorScatterAdd()
  348. array_to_scalar = Primitive('array_to_scalar')
  349. is_ = Primitive("is_")
  350. is_not = Primitive("is_not")
  351. in_dict = Primitive("in_dict")
  352. not_in_dict = Primitive("not_in_dict")
  353. mixed_precision_cast = Primitive("mixed_precision_cast")
  354. broadcast_gradient_args = Primitive('BroadcastGradientArgs')
  355. array_reduce = Primitive('array_reduce')
  356. zeros_like = P.ZerosLike()
  357. distribute = Primitive('distribute')
  358. embed = Primitive('embed')
  359. ref_to_embed = _grad_ops.RefToEmbed()
  360. env_setitem = Primitive('env_setitem')
  361. env_getitem = Primitive('env_getitem')
  362. env_add = Primitive('env_add')
  363. J = Primitive('J')
  364. SliceGetItem = Primitive("SliceGetItem")
  365. switch = Primitive('Switch')
  366. switch_layer = Primitive('switch_layer')
  367. # for sum bprop
  368. reduced_shape = Primitive("reduced_shape")
  369. # shape_mul:input must be shape multiply elements in tuple(shape)
  370. shape_mul = Primitive("shape_mul")
  371. # a primitive to compare between tuple.
  372. stop_gradient = Primitive("stop_gradient")
  373. make_row_tensor = Primitive('MakeRowTensor')
  374. row_tensor_get_values = Primitive('RowTensorGetValues')
  375. row_tensor_get_indices = Primitive('RowTensorGetIndices')
  376. row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape')
  377. row_tensor_add = Primitive('RowTensorAdd')
  378. make_sparse_tensor = Primitive('MakeSparseTensor')
  379. sparse_tensor_get_values = Primitive('SparseTensorGetValues')
  380. sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
  381. sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
  382. make_csr_tensor = Primitive('MakeCSRTensor')
  383. csr_tensor_get_values = Primitive('CSRTensorGetValues')
  384. csr_tensor_get_indices = Primitive('CSRTensorGetIndices')
  385. csr_tensor_get_indptr = Primitive('CSRTensorGetIndptr')
  386. csr_tensor_get_shape = Primitive('CSRTensorGetDenseShape')
# Register the operators that back Tensor methods (Tensor.sum, Tensor.__eq__,
# ...).  Entries registered with an operator CLASS are instantiated by the
# registry consumer; entries registered with an already-created INSTANCE are
# used directly.
tensor_operator_registry.register('all', P.ReduceAll)
tensor_operator_registry.register('any', P.ReduceAny)
tensor_operator_registry.register('abs', P.Abs)
tensor_operator_registry.register('mean', P.ReduceMean)
tensor_operator_registry.register('reshape', P.Reshape)
tensor_operator_registry.register('transpose', P.Transpose)
tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
tensor_operator_registry.register('matmul', P.MatMul)
tensor_operator_registry.register('argmax', P.Argmax)
tensor_operator_registry.register('cumsum', P.CumSum)
tensor_operator_registry.register('reduce_max', P.ReduceMax)
tensor_operator_registry.register('reduce_min', P.ReduceMin)
tensor_operator_registry.register('maximum', P.Maximum)
tensor_operator_registry.register('minimum', P.Minimum)
tensor_operator_registry.register('fill', P.Fill)
tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logical_not', P.LogicalNot)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
tensor_operator_registry.register('__neg__', neg_tensor)
tensor_operator_registry.register('__lt__', tensor_lt)
tensor_operator_registry.register('__le__', tensor_le)
tensor_operator_registry.register('__gt__', tensor_gt)
tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry.register('__logical_not__', logical_not)
tensor_operator_registry.register('shape', shape)
tensor_operator_registry.register('squeeze', squeeze)
# support GE backend for no compare operators
tensor_operator_registry.register('cast', cast)
tensor_operator_registry.register('shape_mul', shape_mul)
# NOTE(review): 'fill' is registered twice — first with the P.Fill class
# above, then with the `fill` instance here; the later registration likely
# overwrites the earlier one.  Confirm which form Tensor.fill expects.
tensor_operator_registry.register('fill', fill)
tensor_operator_registry.register('concatenate', P.Concat)
tensor_operator_registry.register('eye', eye)
tensor_operator_registry.register('reduce_sum', reduce_sum)
tensor_operator_registry.register('tensor_slice', tensor_slice)
tensor_operator_registry.register('select', select)
tensor_operator_registry.register('gather_d', gather_d)
tensor_operator_registry.register('gather_nd', gather_nd)
tensor_operator_registry.register('stack', P.Stack)
tensor_operator_registry.register('log', log)
tensor_operator_registry.register('floor', floor)
  431. __all__ = [name for name in dir() if name[0] != "_"]
  432. __all__.remove('Primitive')