You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_implementations.py 8.7 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """bprop primitives"""
  16. from mindspore.ops import _constants
  17. from ..operations import _grad_ops as G
  18. from .. import functional as F
  19. from .. import operations as P
  20. from ..composite import multitype_ops as C
  21. from .grad_base import bprops
# Module-level `DType` primitive instance, called as `get_dtype(tensor)` to read a
# tensor's dtype (used by the MaximumGrad/MinimumGrad bprops to cast their masks).
get_dtype = P.DType()
# Each bprop below receives the primitive's inputs followed by `out` and `dout`;
# parameters a given bprop does not use are placeholders required by that calling
# convention.
  24. @bprops.register("MaximumGrad")
  25. def bprop_maximum_grad_grad(x, y, z, out, dout):
  26. """Backpropagator for primitive `MaximumGrad`."""
  27. out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
  28. out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
  29. dz = out0 * dout[0] + out1 * dout[1]
  30. return F.zeros_like(x), F.zeros_like(y), dz
  31. @bprops.register("MinimumGrad")
  32. def bprop_minimum_grad_grad(x, y, z, out, dout):
  33. """Backpropagator for primitive `MinimumGrad`."""
  34. out0 = F.cast(out[0] != 0, get_dtype(dout[0]))
  35. out1 = F.cast(out[1] != 0, get_dtype(dout[1]))
  36. dz = out0 * dout[0] + out1 * dout[1]
  37. return F.zeros_like(x), F.zeros_like(y), dz
  38. @bprops.register("ReluGrad")
  39. def bprop_relu_grad_grad(x, y, out, dout):
  40. """Backpropagator for primitive `ReluGrad`."""
  41. input_grad = G.ReluGrad()
  42. dy = input_grad(dout, y)
  43. return dy, F.zeros_like(y)
  44. @bprops.register(_constants.kScalarAdd)
  45. def bprop_scalar_add(x, y, out, dout):
  46. """Backpropagator for primitive `scalar_add`."""
  47. return dout, dout
  48. @bprops.register(_constants.kScalarMul)
  49. def bprop_scalar_mul(x, y, out, dout):
  50. """Backpropagator for primitive `scalar_mul`."""
  51. return dout*y, dout*x
  52. @bprops.register(_constants.kScalarSub)
  53. def bprop_scalar_sub(x, y, out, dout):
  54. """Backpropagator for primitive `scalar_sub`."""
  55. return dout, -dout
  56. @bprops.register(_constants.kScalarDiv)
  57. def bprop_scalar_div(x, y, out, dout):
  58. """Backpropagator for primitive `scalar_div`."""
  59. return dout/y, (-dout) * (out/y)
  60. @bprops.register(_constants.kScalarPow)
  61. def bprop_scalar_pow(x, y, out, dout):
  62. """Backpropagator for primitive `scalar_pow`."""
  63. return dout * (y * (x ** (y-1))), dout * (F.scalar_log(x) * out)
  64. @bprops.register("scalar_exp")
  65. def bprop_scalar_exp(x, out, dout):
  66. """Backpropagator for primitive `scalar_exp`."""
  67. return (dout * out,)
  68. @bprops.register(_constants.kScalarUadd)
  69. def bprop_scalar_uadd(x, out, dout):
  70. """Backpropagator for primitive `scalar_uadd`."""
  71. return (dout,)
  72. @bprops.register(_constants.kScalarUsub)
  73. def bprop_scalar_usub(x, out, dout):
  74. """Backpropagator for primitive `scalar_usub`."""
  75. return (-dout,)
  76. @bprops.register("scalar_gt")
  77. def bprop_scalar_gt(x, y, out, dout):
  78. """Backpropagator for primitive `scalar_gt`."""
  79. return C.zeros_like(x), C.zeros_like(y)
  80. @bprops.register("scalar_lt")
  81. def bprop_scalar_lt(x, y, out, dout):
  82. """Backpropagator for primitive `scalar_lt`."""
  83. return C.zeros_like(x), C.zeros_like(y)
  84. @bprops.register("scalar_ge")
  85. def bprop_scalar_ge(x, y, out, dout):
  86. """Backpropagator for primitive `scalar_ge`."""
  87. return C.zeros_like(x), C.zeros_like(y)
  88. @bprops.register("scalar_le")
  89. def bprop_scalar_le(x, y, out, dout):
  90. """Backpropagator for primitive `scalar_le`."""
  91. return C.zeros_like(x), C.zeros_like(y)
  92. @bprops.register("scalar_eq")
  93. def bprop_scalar_eq(x, y, out, dout):
  94. """Backpropagator for primitive `scalar_eq`."""
  95. return C.zeros_like(x), C.zeros_like(y)
  96. @bprops.register("scalar_ne")
  97. def bprop_scalar_ne(x, y, out, dout):
  98. """Backpropagator for primitive `scalar_eq`."""
  99. return C.zeros_like(x), C.zeros_like(y)
  100. @bprops.register("scalar_cast")
  101. def bprop_scalar_cast(x, t, out, dout):
  102. """Backpropagator for primitive `scalar_cast`."""
  103. return F.scalar_cast(dout, F.typeof(x)), t
  104. @bprops.register(_constants.kTupleGetItem)
  105. def bprop_tuple_getitem(data, idx, out, dout):
  106. """Backpropagator for primitive `tuple_getitem`."""
  107. return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
  108. @bprops.register("list_getitem")
  109. def bprop_list_getitem(data, idx, out, dout):
  110. """Backpropagator for primitive `list_getitem`."""
  111. return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
  112. @bprops.register("identity")
  113. def bprop_identity(x, out, dout):
  114. """Backpropagator for primitive `identity`."""
  115. return (dout,)
  116. @bprops.register("make_ref")
  117. def bprop_make_ref(key, x, y, out, dout):
  118. """Backpropagator for primitive `make_ref`."""
  119. return (C.zeros_like(key), dout, C.zeros_like(y))
  120. @bprops.register("get_ref_value")
  121. def bprop_get_ref_value(x, out, dout):
  122. """Backpropagator for primitive `get_ref_value`."""
  123. return (dout,)
  124. @bprops.register("get_ref_key")
  125. def bprop_get_ref_key(x, out, dout):
  126. """Backpropagator for primitive `get_ref_key`."""
  127. return (C.zeros_like(x),)
  128. @bprops.register("scalar_to_array")
  129. def bprop_scalar_to_array(x, out, dout):
  130. """Backpropagator for primitive `scalar_to_array`."""
  131. return (F.array_to_scalar(dout),)
  132. @bprops.register("array_to_scalar")
  133. def bprop_array_to_scalar(x, out, dout):
  134. """Backpropagator for primitive `array_to_scalar`."""
  135. return (F.scalar_to_array(dout),)
  136. @bprops.register("reshape")
  137. def bprop_reshape(xs, shp, out, dout):
  138. """Backpropagator for primitive `reshape`."""
  139. return F.reshape(dout, F.shape(xs)), C.zeros_like(shp)
  140. @bprops.register("distribute")
  141. def bprop_distribute(arr, shp, out, dout):
  142. """Backpropagator for primitive `distribute`."""
  143. return F.array_reduce(F.scalar_add, dout, F.shape(arr)), C.zeros_like(shp)
  144. @bprops.register("shape")
  145. def bprop_shape(arr, out, dout):
  146. """Backpropagator for primitive `shape`."""
  147. return (C.zeros_like(arr),)
  148. @bprops.register("broadcast_shape")
  149. def bprop_broadcast_shape(shp1, shp2, out, dout):
  150. """Backpropagator for primitive `broadcast_shape`."""
  151. return C.zeros_like(shp1), C.zeros_like(shp2)
  152. @bprops.register("J")
  153. def bprop_j(x, out, dout):
  154. """Backpropagator for primitive `J`."""
  155. return (F.jinv(dout),)
  156. @bprops.register("array_reduce")
  157. def bprop_array_reduce(fn, x, shp, out, dout):
  158. """Backpropagator for primitive `array_reduce`."""
  159. return F.distribute(dout, F.shape(x)), C.zeros_like(shp)
  160. @bprops.register("Depend")
  161. def bprop_depend(x, y, out, dout):
  162. """Backpropagator for primitive `depend`."""
  163. return dout, C.zeros_like(y)
  164. @bprops.register("embed")
  165. def bprop_embed(x, out, dout):
  166. """Backpropagator for primitive `embed`."""
  167. return (C.zeros_like(x),)
  168. @bprops.register("bool_not")
  169. def bprop_bool_not(x, out, dout):
  170. """Backpropagator for primitive `bool_not`."""
  171. return (C.zeros_like(x),)
  172. @bprops.register("bool_or")
  173. def bprop_bool_or(x, y, out, dout):
  174. """Backpropagator for primitive `bool_or`."""
  175. return C.zeros_like(x), C.zeros_like(y)
  176. @bprops.register("stop_gradient")
  177. def bprop_stop_gradient(x, out, dout):
  178. """Backpropagator for primitive `stop_gradient`."""
  179. return (C.zeros_like(x),)
  180. @bprops.register("bool_and")
  181. def bprop_bool_and(x, y, out, dout):
  182. """Backpropagator for primitive `bool_and`."""
  183. return C.zeros_like(x), C.zeros_like(y)
  184. @bprops.register("ControlDepend")
  185. def bprop_control_depend(x, y, out, dout):
  186. """Backpropagator for primitive `Control_depend`."""
  187. return C.zeros_like(x), C.zeros_like(y)
  188. @bprops.register("switch")
  189. def bprop_switch(cond, tb, fb, out, dout):
  190. """Backpropagator for primitive `switch`."""
  191. return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
  192. F.switch(cond, C.zeros_like(fb), dout)
  193. def _fprop_switch_layer(index, layers):
  194. """Backpropagator for primitive `switch_layer`."""
  195. def _bprop_switch_layer(dout):
  196. return dout, C.zeros_like(index), ()
  197. return F.switch_layer(index, layers), _bprop_switch_layer
  198. @bprops.register("UpdateState")
  199. def bprop_update_state(u_monad, x, out, dout):
  200. """Backpropagator for primitive `UpdateState`."""
  201. return C.zeros_like(u_monad), C.zeros_like(x)
  202. @bprops.register("Load")
  203. def bprop_load(param, u_monad, out, dout):
  204. """Backpropagator for primitive `load`."""
  205. return dout, C.zeros_like(u_monad)