You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_implementations.py 8.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """bprop primitives"""
  16. from mindspore.ops import _constants
  17. from ..operations import _grad_ops as G
  18. from .. import functional as F
  19. from .. import operations as P
  20. from ..composite import multitype_ops as C
  21. from .grad_base import bprops
  22. get_dtype = P.DType()
  23. # Unused parameters are placeholders.
  24. @bprops.register("MaximumGrad")
  25. @bprops.register("MinimumGrad")
  26. def bprop_max_and_minimum_grad_grad(x, y, z, out, dout):
  27. """Backpropagator for primitive `MaximumGrad` and `MinimumGrad`."""
  28. out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
  29. out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
  30. dz = out0 * dout[0] + out1 * dout[1]
  31. return F.zeros_like(x), F.zeros_like(y), dz
  32. @bprops.register("ReluGrad")
  33. def bprop_relu_grad_grad(x, y, out, dout):
  34. """Backpropagator for primitive `ReluGrad`."""
  35. input_grad = G.ReluGrad()
  36. dy = input_grad(dout, y)
  37. return dy, F.zeros_like(y)
  38. @bprops.register(_constants.kScalarAdd)
  39. def bprop_scalar_add(x, y, out, dout):
  40. """Backpropagator for primitive `scalar_add`."""
  41. return dout, dout
  42. @bprops.register(_constants.kScalarMul)
  43. def bprop_scalar_mul(x, y, out, dout):
  44. """Backpropagator for primitive `scalar_mul`."""
  45. return dout*y, dout*x
  46. @bprops.register(_constants.kScalarSub)
  47. def bprop_scalar_sub(x, y, out, dout):
  48. """Backpropagator for primitive `scalar_sub`."""
  49. return dout, -dout
  50. @bprops.register(_constants.kScalarDiv)
  51. def bprop_scalar_div(x, y, out, dout):
  52. """Backpropagator for primitive `scalar_div`."""
  53. return dout/y, (-dout) * (out/y)
  54. @bprops.register(_constants.kScalarPow)
  55. def bprop_scalar_pow(x, y, out, dout):
  56. """Backpropagator for primitive `scalar_pow`."""
  57. return dout * (y * (x ** (y-1))), dout * (F.scalar_log(x) * out)
  58. @bprops.register("scalar_exp")
  59. def bprop_scalar_exp(x, out, dout):
  60. """Backpropagator for primitive `scalar_exp`."""
  61. return (dout * out,)
  62. @bprops.register(_constants.kScalarUadd)
  63. def bprop_scalar_uadd(x, out, dout):
  64. """Backpropagator for primitive `scalar_uadd`."""
  65. return (dout,)
  66. @bprops.register(_constants.kScalarUsub)
  67. def bprop_scalar_usub(x, out, dout):
  68. """Backpropagator for primitive `scalar_usub`."""
  69. return (-dout,)
  70. @bprops.register("scalar_gt")
  71. def bprop_scalar_gt(x, y, out, dout):
  72. """Backpropagator for primitive `scalar_gt`."""
  73. return C.zeros_like(x), C.zeros_like(y)
  74. @bprops.register("scalar_lt")
  75. def bprop_scalar_lt(x, y, out, dout):
  76. """Backpropagator for primitive `scalar_lt`."""
  77. return C.zeros_like(x), C.zeros_like(y)
  78. @bprops.register("scalar_ge")
  79. def bprop_scalar_ge(x, y, out, dout):
  80. """Backpropagator for primitive `scalar_ge`."""
  81. return C.zeros_like(x), C.zeros_like(y)
  82. @bprops.register("scalar_le")
  83. def bprop_scalar_le(x, y, out, dout):
  84. """Backpropagator for primitive `scalar_le`."""
  85. return C.zeros_like(x), C.zeros_like(y)
  86. @bprops.register("scalar_eq")
  87. def bprop_scalar_eq(x, y, out, dout):
  88. """Backpropagator for primitive `scalar_eq`."""
  89. return C.zeros_like(x), C.zeros_like(y)
  90. @bprops.register("scalar_ne")
  91. def bprop_scalar_ne(x, y, out, dout):
  92. """Backpropagator for primitive `scalar_eq`."""
  93. return C.zeros_like(x), C.zeros_like(y)
  94. @bprops.register("scalar_cast")
  95. def bprop_scalar_cast(x, t, out, dout):
  96. """Backpropagator for primitive `scalar_cast`."""
  97. return F.scalar_cast(dout, F.typeof(x)), t
  98. @bprops.register(_constants.kTupleGetItem)
  99. def bprop_tuple_getitem(data, idx, out, dout):
  100. """Backpropagator for primitive `tuple_getitem`."""
  101. return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
  102. @bprops.register("list_getitem")
  103. def bprop_list_getitem(data, idx, out, dout):
  104. """Backpropagator for primitive `list_getitem`."""
  105. return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
  106. @bprops.register("identity")
  107. def bprop_identity(x, out, dout):
  108. """Backpropagator for primitive `identity`."""
  109. return (dout,)
  110. @bprops.register("make_ref")
  111. def bprop_make_ref(key, x, y, out, dout):
  112. """Backpropagator for primitive `make_ref`."""
  113. return (C.zeros_like(key), dout, C.zeros_like(y))
  114. @bprops.register("get_ref_value")
  115. def bprop_get_ref_value(x, out, dout):
  116. """Backpropagator for primitive `get_ref_value`."""
  117. return (dout,)
  118. @bprops.register("get_ref_key")
  119. def bprop_get_ref_key(x, out, dout):
  120. """Backpropagator for primitive `get_ref_key`."""
  121. return (C.zeros_like(x),)
  122. @bprops.register("scalar_to_array")
  123. def bprop_scalar_to_array(x, out, dout):
  124. """Backpropagator for primitive `scalar_to_array`."""
  125. return (F.array_to_scalar(dout),)
  126. @bprops.register("array_to_scalar")
  127. def bprop_array_to_scalar(x, out, dout):
  128. """Backpropagator for primitive `array_to_scalar`."""
  129. return (F.scalar_to_array(dout),)
  130. @bprops.register("reshape")
  131. def bprop_reshape(xs, shp, out, dout):
  132. """Backpropagator for primitive `reshape`."""
  133. return F.reshape(dout, F.shape(xs)), C.zeros_like(shp)
  134. @bprops.register("distribute")
  135. def bprop_distribute(arr, shp, out, dout):
  136. """Backpropagator for primitive `distribute`."""
  137. return F.array_reduce(F.scalar_add, dout, F.shape(arr)), C.zeros_like(shp)
  138. @bprops.register("shape")
  139. def bprop_shape(arr, out, dout):
  140. """Backpropagator for primitive `shape`."""
  141. return (C.zeros_like(arr),)
  142. @bprops.register("broadcast_shape")
  143. def bprop_broadcast_shape(shp1, shp2, out, dout):
  144. """Backpropagator for primitive `broadcast_shape`."""
  145. return C.zeros_like(shp1), C.zeros_like(shp2)
  146. @bprops.register("array_reduce")
  147. def bprop_array_reduce(fn, x, shp, out, dout):
  148. """Backpropagator for primitive `array_reduce`."""
  149. return F.distribute(dout, F.shape(x)), C.zeros_like(shp)
  150. @bprops.register("Depend")
  151. def bprop_depend(x, y, out, dout):
  152. """Backpropagator for primitive `depend`."""
  153. return dout, C.zeros_like(y)
  154. @bprops.register("embed")
  155. def bprop_embed(x, out, dout):
  156. """Backpropagator for primitive `embed`."""
  157. return (C.zeros_like(x),)
  158. @bprops.register("bool_not")
  159. def bprop_bool_not(x, out, dout):
  160. """Backpropagator for primitive `bool_not`."""
  161. return (C.zeros_like(x),)
  162. @bprops.register("bool_or")
  163. def bprop_bool_or(x, y, out, dout):
  164. """Backpropagator for primitive `bool_or`."""
  165. return C.zeros_like(x), C.zeros_like(y)
  166. @bprops.register("stop_gradient")
  167. def bprop_stop_gradient(x, out, dout):
  168. """Backpropagator for primitive `stop_gradient`."""
  169. return (C.zeros_like(x),)
  170. @bprops.register("bool_and")
  171. def bprop_bool_and(x, y, out, dout):
  172. """Backpropagator for primitive `bool_and`."""
  173. return C.zeros_like(x), C.zeros_like(y)
  174. @bprops.register("Switch")
  175. def bprop_switch(cond, tb, fb, out, dout):
  176. """Backpropagator for primitive `switch`."""
  177. return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
  178. F.switch(cond, C.zeros_like(fb), dout)
  179. def _fprop_switch_layer(index, layers):
  180. """Backpropagator for primitive `switch_layer`."""
  181. def _bprop_switch_layer(dout):
  182. return dout, C.zeros_like(index), ()
  183. return F.switch_layer(index, layers), _bprop_switch_layer
  184. @bprops.register("UpdateState")
  185. def bprop_update_state(u_monad, x, out, dout):
  186. """Backpropagator for primitive `UpdateState`."""
  187. return C.zeros_like(u_monad), C.zeros_like(x)
  188. @bprops.register("Load")
  189. def bprop_load(param, u_monad, out, dout):
  190. """Backpropagator for primitive `load`."""
  191. return dout, C.zeros_like(u_monad)