You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_array_ops.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array_ops"""
  16. from .. import operations as P
  17. from ..operations import _grad_ops as G
  18. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  19. from .. import functional as F
  20. from .grad_base import bprop_getters
  21. from ..primitive import constexpr
  22. from ... import context
  23. reduce_sum = P.ReduceSum()
  24. unsorted_segment_sum = P.UnsortedSegmentSum()
  25. transpose = P.Transpose()
  26. shape_op = P.Shape()
  27. reshape = P.Reshape()
  28. invert_permutation = P.InvertPermutation()
  29. @bprop_getters.register(P.Fill)
  30. def get_bprop_fill(self):
  31. """Generate bprop for Fill"""
  32. def bprop(dtype, dims, x, out, dout):
  33. return zeros_like(dims), zeros_like(x)
  34. return bprop
  35. @bprop_getters.register(P.DType)
  36. def get_bprop_dtype(self):
  37. """Generate bprop for DType"""
  38. def bprop(x, out, dout):
  39. return (zeros_like(x),)
  40. return bprop
  41. @bprop_getters.register(P.Cast)
  42. def get_bprop_cast(self):
  43. """Generate bprop for Cast"""
  44. cast = P.Cast()
  45. get_dtype = P.DType()
  46. def bprop(x, t, out, dout):
  47. dx = cast(dout, get_dtype(x))
  48. return dx, zeros_like(t)
  49. return bprop
  50. @bprop_getters.register(P.Shape)
  51. def get_bprop_shape(self):
  52. """Generate bprop for Shape"""
  53. def bprop(x, out, dout):
  54. return (zeros_like(x),)
  55. return bprop
  56. @bprop_getters.register(P.Split)
  57. def get_bprop_split(self):
  58. """Generate bprop for Split"""
  59. axis = self.axis
  60. def bprop(x, out, dout):
  61. concat_op = P.Concat(axis)
  62. dx = concat_op(dout)
  63. return (dx,)
  64. return bprop
  65. @bprop_getters.register(P.Rank)
  66. def get_bprop_rank(self):
  67. """Generate bprop for Rank"""
  68. def bprop(x, out, dout):
  69. return (zeros_like(x),)
  70. return bprop
  71. @bprop_getters.register(P.Reshape)
  72. def get_bprop_reshape(self):
  73. """Generate bprop for Reshape"""
  74. def bprop(x, shp, out, dout):
  75. shapex = shape_op(x)
  76. return reshape(dout, shapex), zeros_like(shp)
  77. return bprop
  78. @bprop_getters.register(P.ExpandDims)
  79. def get_bprop_expand_dims(self):
  80. """Generate bprop for ExpandDims"""
  81. def bprop(x, axis, out, dout):
  82. shapex = shape_op(x)
  83. return reshape(dout, shapex), zeros_like(axis)
  84. return bprop
  85. @bprop_getters.register(P.Squeeze)
  86. def get_bprop_squeeze(self):
  87. """Generate bprop for Squeeze"""
  88. def bprop(x, out, dout):
  89. shapex = shape_op(x)
  90. return (reshape(dout, shapex),)
  91. return bprop
  92. @bprop_getters.register(P.Flatten)
  93. def get_bprop_flatten(self):
  94. """Generate bprop for Flatten"""
  95. flatten_grad = G.FlattenGrad()
  96. def bprop(x, out, dout):
  97. dx = flatten_grad(dout, shape_op(x))
  98. return (dx,)
  99. return bprop
  100. @constexpr
  101. def _tile_shape(multiples, shapex):
  102. """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
  103. len_muli = len(multiples)
  104. rank = len(shapex)
  105. len_cmp = len_muli - rank
  106. max_len = max(len_muli, rank)
  107. i = 0
  108. j = 0
  109. ret = []
  110. while (i < max_len) and (j < max_len):
  111. if len_cmp == 0:
  112. ret.append(multiples[i])
  113. ret.append(shapex[j])
  114. i += 1
  115. j += 1
  116. elif len_cmp > 0:
  117. ret.append(multiples[i])
  118. ret.append(1)
  119. i += 1
  120. len_cmp -= 1
  121. else:
  122. ret.append(1)
  123. ret.append(shapex[j])
  124. len_cmp += 1
  125. return tuple(ret)
  126. @bprop_getters.register(P.Tile)
  127. def get_bprop_tile(self):
  128. """Generate bprop for Tile"""
  129. def bprop(x, multiples, out, dout):
  130. shapex = shape_op(x)
  131. r_shape = _tile_shape(multiples, shapex)
  132. # 0 represents the start index, and 2 represents the step
  133. axis = F.make_range(0, len(r_shape), 2)
  134. dx = reduce_sum(reshape(dout, r_shape), axis)
  135. dx = reshape(dx, shapex)
  136. return dx, zeros_like(multiples)
  137. return bprop
  138. @bprop_getters.register(P.Transpose)
  139. def get_bprop_transpose(self):
  140. """Generate bprop for Transpose"""
  141. def bprop(x, perm, out, dout):
  142. return transpose(dout, invert_permutation(perm)), zeros_like(perm)
  143. return bprop
  144. @bprop_getters.register(P.Concat)
  145. def get_bprop_concat(self):
  146. """Generate bprop for Concat"""
  147. axis = self.axis
  148. def bprop(x, out, dout):
  149. dx = ()
  150. out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
  151. for i in range(F.tuple_len(x)):
  152. slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
  153. dx = dx + (slice_out,)
  154. return (dx,)
  155. return bprop
  156. @constexpr
  157. def _slice_grad_pad(begins, sizes, shapes):
  158. pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
  159. return pads
  160. @bprop_getters.register(P.Slice)
  161. def get_bprop_slice(self):
  162. """Generate bprop for Slice"""
  163. def bprop(x, begin, size, out, dout):
  164. dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
  165. return (dx, zeros_like(begin), zeros_like(size))
  166. def bprop_gpu(x, begin, size, out, dout):
  167. dx = dx = G.SliceGrad()(dout, x, begin, size)
  168. return (dx, zeros_like(begin), zeros_like(size))
  169. if context.get_context('device_target') == "GPU":
  170. return bprop_gpu
  171. return bprop
  172. @constexpr
  173. def _generate_shape_index(out_shape, indices_shape, axis):
  174. out_rank = len(out_shape)
  175. ind_rank = len(indices_shape)
  176. if axis < 0:
  177. axis += out_rank - ind_rank + 1
  178. perm_part1 = tuple(range(axis, axis + ind_rank))
  179. index = tuple(range(out_rank))
  180. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  181. return perm
  182. @constexpr
  183. def _generate_inverse_index(x_shape, axis):
  184. x_rank = len(x_shape)
  185. index = tuple(range(x_rank))
  186. if axis < 0:
  187. axis += x_rank
  188. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  189. return perm
  190. @bprop_getters.register(P.GatherV2)
  191. def get_bprop_gather_v2(self):
  192. """Generate bprop for GatherV2"""
  193. def bprop(x, indices, axis, out, dout):
  194. if F.rank(dout) == 0:
  195. dout = P.ExpandDims()(dout, -1)
  196. if F.rank(indices) == 0:
  197. indices = P.ExpandDims()(indices, -1)
  198. x_shp = shape_op(x)
  199. out_shp = shape_op(dout)
  200. ind_shp = shape_op(indices)
  201. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  202. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  203. values_transpose = transpose(dout, perm_1)
  204. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  205. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  206. perm_2 = _generate_inverse_index(x_shp, axis)
  207. params_grad = transpose(params_grad, perm_2)
  208. return params_grad, zeros_like(indices), zeros_like(axis)
  209. return bprop
  210. @bprop_getters.register(P.Pack)
  211. def get_bprop_pack(self):
  212. """Generate bprop for Pack"""
  213. axis = self.axis
  214. def bprop(x, out, dout):
  215. pack_grad = P.Unpack(axis)
  216. out = pack_grad(dout)
  217. return (out,)
  218. return bprop
  219. @bprop_getters.register(P.Unpack)
  220. def get_bprop_unpack(self):
  221. """Generate bprop for Unpack"""
  222. axis = self.axis
  223. def bprop(x, out, dout):
  224. unpack_grad = P.Pack(axis)
  225. out = unpack_grad(dout)
  226. return (out,)
  227. return bprop
  228. @bprop_getters.register(P.StridedSlice)
  229. def get_bprop_strided_slice(self):
  230. """Generate bprop for StridedSlice"""
  231. input_grad = G.StridedSliceGrad(self.begin_mask,
  232. self.end_mask,
  233. self.ellipsis_mask,
  234. self.new_axis_mask,
  235. self.shrink_axis_mask)
  236. def bprop(x, begin, end, strides, out, dout):
  237. dx = input_grad(dout, shape_op(x), begin, end, strides)
  238. return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
  239. return bprop
  240. @bprop_getters.register(P.Eye)
  241. def get_bprop_eye(self):
  242. """Generate bprop for Eye"""
  243. def bprop(n, m, t, out, dout):
  244. return zeros_like(n), zeros_like(m), zeros_like(t)
  245. return bprop
  246. @bprop_getters.register(P.Select)
  247. def get_bprop_select(self):
  248. """Generate bprop for Select"""
  249. select = P.Select()
  250. def bprop(cond, x, y, out, dout):
  251. return zeros_like(cond), select(cond, dout, zeros_like(x)), select(cond, zeros_like(y), dout)
  252. return bprop
  253. @bprop_getters.register(P.OnesLike)
  254. def get_bprop_oneslike(self):
  255. """Generate bprop for OnesLike"""
  256. def bprop(x, out, dout):
  257. return (zeros_like(x),)
  258. return bprop
  259. @bprop_getters.register(P.ZerosLike)
  260. def get_bprop_zeroslike(self):
  261. """Generate bprop for OnesLike"""
  262. def bprop(x, out, dout):
  263. return (zeros_like(x),)
  264. return bprop
  265. @bprop_getters.register(P.ResizeNearestNeighbor)
  266. def get_bprop_resize_nearest_neighbor(self):
  267. """Generate bprop for ResizeNearestNeighbor"""
  268. op = G.ResizeNearestNeighborGrad(self.align_corners)
  269. def bprop(inputs, out, dout):
  270. shp = shape_op(inputs)
  271. # 2 and 3 represent the height and width
  272. shp = (shp[2], shp[3])
  273. return (op(dout, shp),)
  274. return bprop
  275. @bprop_getters.register(P.GatherNd)
  276. def get_bprop_gather_nd(self):
  277. """Generate bprop for GatherNd"""
  278. op = P.ScatterNd()
  279. def bprop(x, indices, out, dout):
  280. shp = shape_op(x)
  281. return op(indices, dout, shp), zeros_like(indices)
  282. return bprop
  283. @bprop_getters.register(P.ScatterNd)
  284. def get_bprop_scatter_nd(self):
  285. """Generate bprop for ScatterNd"""
  286. op = P.GatherNd()
  287. def bprop(indices, x, shape, out, dout):
  288. return zeros_like(indices), op(dout, indices), zeros_like(shape)
  289. return bprop
  290. @bprop_getters.register(P.ScatterNdUpdate)
  291. def get_bprop_scatter_nd_update(self):
  292. """Generate bprop for ScatterNdUpdate"""
  293. op = P.GatherNd()
  294. def bprop(x, indices, update, out, dout):
  295. return dout, zeros_like(indices), op(dout, indices)
  296. return bprop
  297. @bprop_getters.register(P.Argmax)
  298. def get_bprop_argmax(self):
  299. """Generate bprop for Argmax"""
  300. def bprop(x, out, dout):
  301. return (zeros_like(x),)
  302. return bprop
  303. @bprop_getters.register(P.Argmin)
  304. def get_bprop_argmin(self):
  305. """Generate bprop for Argmin"""
  306. def bprop(x, out, dout):
  307. return (zeros_like(x),)
  308. return bprop
  309. @bprop_getters.register(P.SpaceToDepth)
  310. def get_bprop_space_to_depth(self):
  311. """Generate bprop for SpaceToDepth"""
  312. op = P.DepthToSpace(self.block_size)
  313. def bprop(x, out, dout):
  314. return (op(dout),)
  315. return bprop
  316. @bprop_getters.register(P.DepthToSpace)
  317. def get_bprop_depth_to_space(self):
  318. """Generate bprop for DepthToSpace"""
  319. op = P.SpaceToDepth(self.block_size)
  320. def bprop(x, out, dout):
  321. return (op(dout),)
  322. return bprop
  323. @bprop_getters.register(P.Diag)
  324. def get_bprop_diag(self):
  325. """Generate bprop for Diag"""
  326. op = P.DiagPart()
  327. def bprop(x, out, dout):
  328. return (op(dout),)
  329. return bprop
  330. @bprop_getters.register(P.DiagPart)
  331. def get_bprop_diag_part(self):
  332. """Generate bprop for DiagPart"""
  333. op = P.Diag()
  334. def bprop(x, out, dout):
  335. return (op(dout),)
  336. return bprop
  337. @bprop_getters.register(P.SpaceToBatch)
  338. def get_bprop_space_to_batch(self):
  339. """Generate bprop for SpaceToBatch"""
  340. space_to_batch_grad = P.BatchToSpace(self.block_size, self.paddings)
  341. def bprop(x, out, dout):
  342. dx = space_to_batch_grad(dout)
  343. return (dx,)
  344. return bprop
  345. @bprop_getters.register(P.BatchToSpace)
  346. def get_bprop_batch_to_space(self):
  347. """Generate bprop for BatchToSpace"""
  348. batch_to_space_grad = P.SpaceToBatch(self.block_size, self.crops)
  349. def bprop(x, out, dout):
  350. dx = batch_to_space_grad(dout)
  351. return (dx,)
  352. return bprop