You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_implementations.py 7.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """bprop primitives"""
  16. from .. import functional as F
  17. from ..composite import multitype_ops as C
  18. from .grad_base import bprops
  19. # Unused parameters are placeholders.
  20. @bprops.register("scalar_add")
  21. def bprop_scalar_add(x, y, out, dout):
  22. """Backpropagator for primitive `scalar_add`."""
  23. return dout, dout
  24. @bprops.register("scalar_mul")
  25. def bprop_scalar_mul(x, y, out, dout):
  26. """Backpropagator for primitive `scalar_mul`."""
  27. return dout*y, dout*x
  28. @bprops.register("scalar_sub")
  29. def bprop_scalar_sub(x, y, out, dout):
  30. """Backpropagator for primitive `scalar_sub`."""
  31. return dout, -dout
  32. @bprops.register("scalar_div")
  33. def bprop_scalar_div(x, y, out, dout):
  34. """Backpropagator for primitive `scalar_div`."""
  35. return dout/y, (-dout) * (out/y)
  36. @bprops.register("scalar_pow")
  37. def bprop_scalar_pow(x, y, out, dout):
  38. """Backpropagator for primitive `scalar_pow`."""
  39. return dout * (y * (x ** (y-1))), dout * (F.scalar_log(x) * out)
  40. @bprops.register("scalar_exp")
  41. def bprop_scalar_exp(x, out, dout):
  42. """Backpropagator for primitive `scalar_exp`."""
  43. return (dout * out,)
  44. @bprops.register("scalar_uadd")
  45. def bprop_scalar_uadd(x, out, dout):
  46. """Backpropagator for primitive `scalar_uadd`."""
  47. return (dout,)
  48. @bprops.register("scalar_usub")
  49. def bprop_scalar_usub(x, out, dout):
  50. """Backpropagator for primitive `scalar_usub`."""
  51. return (-dout,)
  52. @bprops.register("scalar_gt")
  53. def bprop_scalar_gt(x, y, out, dout):
  54. """Backpropagator for primitive `scalar_gt`."""
  55. return C.zeros_like(x), C.zeros_like(y)
  56. @bprops.register("scalar_lt")
  57. def bprop_scalar_lt(x, y, out, dout):
  58. """Backpropagator for primitive `scalar_lt`."""
  59. return C.zeros_like(x), C.zeros_like(y)
  60. @bprops.register("scalar_ge")
  61. def bprop_scalar_ge(x, y, out, dout):
  62. """Backpropagator for primitive `scalar_ge`."""
  63. return C.zeros_like(x), C.zeros_like(y)
  64. @bprops.register("scalar_le")
  65. def bprop_scalar_le(x, y, out, dout):
  66. """Backpropagator for primitive `scalar_le`."""
  67. return C.zeros_like(x), C.zeros_like(y)
  68. @bprops.register("scalar_eq")
  69. def bprop_scalar_eq(x, y, out, dout):
  70. """Backpropagator for primitive `scalar_eq`."""
  71. return C.zeros_like(x), C.zeros_like(y)
  72. @bprops.register("scalar_ne")
  73. def bprop_scalar_ne(x, y, out, dout):
  74. """Backpropagator for primitive `scalar_eq`."""
  75. return C.zeros_like(x), C.zeros_like(y)
  76. @bprops.register("scalar_cast")
  77. def bprop_scalar_cast(x, t, out, dout):
  78. """Backpropagator for primitive `scalar_cast`."""
  79. return F.scalar_cast(dout, F.typeof(x)), t
  80. @bprops.register("tuple_getitem")
  81. def bprop_tuple_getitem(data, idx, out, dout):
  82. """Backpropagator for primitive `tuple_getitem`."""
  83. return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
  84. @bprops.register("identity")
  85. def bprop_identity(x, out, dout):
  86. """Backpropagator for primitive `identity`."""
  87. return (dout,)
  88. @bprops.register("make_ref")
  89. def bprop_make_ref(key, x, y, out, dout):
  90. """Backpropagator for primitive `make_ref`."""
  91. return (C.zeros_like(key), dout, C.zeros_like(y))
  92. @bprops.register("get_ref_value")
  93. def bprop_get_ref_value(x, out, dout):
  94. """Backpropagator for primitive `get_ref_value`."""
  95. return (dout,)
  96. @bprops.register("get_ref_key")
  97. def bprop_get_ref_key(x, out, dout):
  98. """Backpropagator for primitive `get_ref_key`."""
  99. return (C.zeros_like(x),)
  100. @bprops.register("scalar_to_array")
  101. def bprop_scalar_to_array(x, out, dout):
  102. """Backpropagator for primitive `scalar_to_array`."""
  103. return (F.array_to_scalar(dout),)
  104. @bprops.register("array_to_scalar")
  105. def bprop_array_to_scalar(x, out, dout):
  106. """Backpropagator for primitive `array_to_scalar`."""
  107. return (F.scalar_to_array(dout),)
  108. @bprops.register("dot")
  109. def bprop_dot(x, y, out, dout):
  110. """Backpropagator for primitive `dot`."""
  111. return F.dot(dout, F.transpose(y, (1, 0))), F.dot(F.transpose(x, (1, 0)), dout)
  112. @bprops.register("reshape")
  113. def bprop_reshape(xs, shp, out, dout):
  114. """Backpropagator for primitive `reshape`."""
  115. return F.reshape(dout, F.shape(xs)), C.zeros_like(shp)
  116. @bprops.register("distribute")
  117. def bprop_distribute(arr, shp, out, dout):
  118. """Backpropagator for primitive `distribute`."""
  119. return F.array_reduce(F.scalar_add, dout, F.shape(arr)), C.zeros_like(shp)
  120. @bprops.register("shape")
  121. def bprop_shape(arr, out, dout):
  122. """Backpropagator for primitive `shape`."""
  123. return (C.zeros_like(arr),)
  124. @bprops.register("broadcast_shape")
  125. def bprop_broadcast_shape(shp1, shp2, out, dout):
  126. """Backpropagator for primitive `broadcast_shape`."""
  127. return C.zeros_like(shp1), C.zeros_like(shp2)
  128. @bprops.register("J")
  129. def bprop_j(x, out, dout):
  130. """Backpropagator for primitive `J`."""
  131. return (F.jinv(dout),)
  132. @bprops.register("array_reduce")
  133. def bprop_array_reduce(fn, x, shp, out, dout):
  134. """Backpropagator for primitive `array_reduce`."""
  135. return F.distribute(dout, F.shape(x)), C.zeros_like(shp)
  136. @bprops.register("depend")
  137. def bprop_depend(x, y, out, dout):
  138. """Backpropagator for primitive `depend`."""
  139. return dout, C.zeros_like(y)
  140. @bprops.register("embed")
  141. def bprop_embed(x, out, dout):
  142. """Backpropagator for primitive `embed`."""
  143. return (C.zeros_like(x),)
  144. @bprops.register("bool_not")
  145. def bprop_bool_not(x, out, dout):
  146. """Backpropagator for primitive `bool_not`."""
  147. return (C.zeros_like(x),)
  148. @bprops.register("bool_or")
  149. def bprop_bool_or(x, y, out, dout):
  150. """Backpropagator for primitive `bool_or`."""
  151. return C.zeros_like(x), C.zeros_like(y)
  152. @bprops.register("stop_gradient")
  153. def bprop_stop_gradient(x, out, dout):
  154. """Backpropagator for primitive `stop_gradient`."""
  155. return (C.zeros_like(x),)
  156. @bprops.register("bool_and")
  157. def bprop_bool_and(x, y, out, dout):
  158. """Backpropagator for primitive `bool_and`."""
  159. return C.zeros_like(x), C.zeros_like(y)
  160. @bprops.register("ControlDepend")
  161. def bprop_control_depend(x, y, out, dout):
  162. """Backpropagator for primitive `Control_depend`."""
  163. return C.zeros_like(x), C.zeros_like(y)
  164. @bprops.register("switch")
  165. def bprop_switch(cond, tb, fb, out, dout):
  166. """Backpropagator for primitive `switch`."""
  167. return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
  168. F.switch(cond, C.zeros_like(fb), dout)