You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_array_ops.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array_ops"""
  16. from .. import operations as P
  17. from ..operations import _grad_ops as G
  18. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  19. from .. import functional as F
  20. from .grad_base import bprop_getters
  21. from ..primitive import constexpr
  22. from ... import context
  23. from ...common import dtype as mstype
  24. reduce_sum = P.ReduceSum()
  25. unsorted_segment_sum = P.UnsortedSegmentSum()
  26. transpose = P.Transpose()
  27. shape_op = P.Shape()
  28. reshape = P.Reshape()
  29. invert_permutation = P.InvertPermutation()
  30. logical_and = P.LogicalAnd()
  31. @bprop_getters.register(P.Fill)
  32. def get_bprop_fill(self):
  33. """Generate bprop for Fill"""
  34. def bprop(dtype, dims, x, out, dout):
  35. return zeros_like(dims), zeros_like(x)
  36. return bprop
  37. @bprop_getters.register(P.DType)
  38. def get_bprop_dtype(self):
  39. """Generate bprop for DType"""
  40. def bprop(x, out, dout):
  41. return (zeros_like(x),)
  42. return bprop
  43. @bprop_getters.register(P.Cast)
  44. def get_bprop_cast(self):
  45. """Generate bprop for Cast"""
  46. cast = P.Cast()
  47. get_dtype = P.DType()
  48. def bprop(x, t, out, dout):
  49. dx = cast(dout, get_dtype(x))
  50. return dx, zeros_like(t)
  51. return bprop
  52. @bprop_getters.register(P.Shape)
  53. def get_bprop_shape(self):
  54. """Generate bprop for Shape"""
  55. def bprop(x, out, dout):
  56. return (zeros_like(x),)
  57. return bprop
  58. @bprop_getters.register(P.Split)
  59. def get_bprop_split(self):
  60. """Generate bprop for Split"""
  61. axis = self.axis
  62. def bprop(x, out, dout):
  63. concat_op = P.Concat(axis)
  64. dx = concat_op(dout)
  65. return (dx,)
  66. return bprop
  67. @bprop_getters.register(P.Rank)
  68. def get_bprop_rank(self):
  69. """Generate bprop for Rank"""
  70. def bprop(x, out, dout):
  71. return (zeros_like(x),)
  72. return bprop
  73. @bprop_getters.register(P.Reshape)
  74. def get_bprop_reshape(self):
  75. """Generate bprop for Reshape"""
  76. def bprop(x, shp, out, dout):
  77. shapex = shape_op(x)
  78. return reshape(dout, shapex), zeros_like(shp)
  79. return bprop
  80. @bprop_getters.register(P.ExpandDims)
  81. def get_bprop_expand_dims(self):
  82. """Generate bprop for ExpandDims"""
  83. def bprop(x, axis, out, dout):
  84. shapex = shape_op(x)
  85. return reshape(dout, shapex), zeros_like(axis)
  86. return bprop
  87. @bprop_getters.register(P.Squeeze)
  88. def get_bprop_squeeze(self):
  89. """Generate bprop for Squeeze"""
  90. def bprop(x, out, dout):
  91. shapex = shape_op(x)
  92. return (reshape(dout, shapex),)
  93. return bprop
  94. @bprop_getters.register(P.Flatten)
  95. def get_bprop_flatten(self):
  96. """Generate bprop for Flatten"""
  97. flatten_grad = G.FlattenGrad()
  98. def bprop(x, out, dout):
  99. dx = flatten_grad(dout, shape_op(x))
  100. return (dx,)
  101. return bprop
  102. @constexpr
  103. def _tile_shape(multiples, shapex):
  104. """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
  105. len_muli = len(multiples)
  106. rank = len(shapex)
  107. len_cmp = len_muli - rank
  108. max_len = max(len_muli, rank)
  109. i = 0
  110. j = 0
  111. ret = []
  112. while (i < max_len) and (j < max_len):
  113. if len_cmp == 0:
  114. ret.append(multiples[i])
  115. ret.append(shapex[j])
  116. i += 1
  117. j += 1
  118. elif len_cmp > 0:
  119. ret.append(multiples[i])
  120. ret.append(1)
  121. i += 1
  122. len_cmp -= 1
  123. else:
  124. ret.append(1)
  125. ret.append(shapex[j])
  126. len_cmp += 1
  127. return tuple(ret)
  128. @bprop_getters.register(P.Tile)
  129. def get_bprop_tile(self):
  130. """Generate bprop for Tile"""
  131. def bprop(x, multiples, out, dout):
  132. shapex = shape_op(x)
  133. r_shape = _tile_shape(multiples, shapex)
  134. # 0 represents the start index, and 2 represents the step
  135. axis = F.make_range(0, len(r_shape), 2)
  136. dx = reduce_sum(reshape(dout, r_shape), axis)
  137. dx = reshape(dx, shapex)
  138. return dx, zeros_like(multiples)
  139. return bprop
  140. @bprop_getters.register(P.Transpose)
  141. def get_bprop_transpose(self):
  142. """Generate bprop for Transpose"""
  143. def bprop(x, perm, out, dout):
  144. return transpose(dout, invert_permutation(perm)), zeros_like(perm)
  145. return bprop
  146. @bprop_getters.register(P.Concat)
  147. def get_bprop_concat(self):
  148. """Generate bprop for Concat"""
  149. axis = self.axis
  150. def bprop(x, out, dout):
  151. dx = ()
  152. out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
  153. for i in range(F.tuple_len(x)):
  154. slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
  155. dx = dx + (slice_out,)
  156. return (dx,)
  157. return bprop
  158. @constexpr
  159. def _slice_grad_pad(begins, sizes, shapes):
  160. pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
  161. return pads
  162. @bprop_getters.register(P.Slice)
  163. def get_bprop_slice(self):
  164. """Generate bprop for Slice"""
  165. def bprop(x, begin, size, out, dout):
  166. dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
  167. return (dx, zeros_like(begin), zeros_like(size))
  168. def bprop_gpu(x, begin, size, out, dout):
  169. dx = dx = G.SliceGrad()(dout, x, begin, size)
  170. return (dx, zeros_like(begin), zeros_like(size))
  171. if context.get_context('device_target') == "GPU":
  172. return bprop_gpu
  173. return bprop
  174. @constexpr
  175. def _generate_shape_index(out_shape, indices_shape, axis):
  176. out_rank = len(out_shape)
  177. ind_rank = len(indices_shape)
  178. if axis < 0:
  179. axis += out_rank - ind_rank + 1
  180. perm_part1 = tuple(range(axis, axis + ind_rank))
  181. index = tuple(range(out_rank))
  182. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  183. return perm
  184. @constexpr
  185. def _generate_inverse_index(x_shape, axis):
  186. x_rank = len(x_shape)
  187. index = tuple(range(x_rank))
  188. if axis < 0:
  189. axis += x_rank
  190. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  191. return perm
  192. @bprop_getters.register(P.GatherV2)
  193. def get_bprop_gather_v2(self):
  194. """Generate bprop for GatherV2"""
  195. def bprop(x, indices, axis, out, dout):
  196. if F.rank(dout) == 0:
  197. dout = P.ExpandDims()(dout, -1)
  198. if F.rank(indices) == 0:
  199. indices = P.ExpandDims()(indices, -1)
  200. x_shp = shape_op(x)
  201. out_shp = shape_op(dout)
  202. ind_shp = shape_op(indices)
  203. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  204. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  205. values_transpose = transpose(dout, perm_1)
  206. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  207. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  208. perm_2 = _generate_inverse_index(x_shp, axis)
  209. params_grad = transpose(params_grad, perm_2)
  210. return params_grad, zeros_like(indices), zeros_like(axis)
  211. return bprop
  212. @bprop_getters.register(P.Pack)
  213. def get_bprop_pack(self):
  214. """Generate bprop for Pack"""
  215. axis = self.axis
  216. def bprop(x, out, dout):
  217. pack_grad = P.Unpack(axis)
  218. out = pack_grad(dout)
  219. return (out,)
  220. return bprop
  221. @bprop_getters.register(P.Unpack)
  222. def get_bprop_unpack(self):
  223. """Generate bprop for Unpack"""
  224. axis = self.axis
  225. def bprop(x, out, dout):
  226. unpack_grad = P.Pack(axis)
  227. out = unpack_grad(dout)
  228. return (out,)
  229. return bprop
  230. @bprop_getters.register(P.StridedSlice)
  231. def get_bprop_strided_slice(self):
  232. """Generate bprop for StridedSlice"""
  233. input_grad = G.StridedSliceGrad(self.begin_mask,
  234. self.end_mask,
  235. self.ellipsis_mask,
  236. self.new_axis_mask,
  237. self.shrink_axis_mask)
  238. def bprop(x, begin, end, strides, out, dout):
  239. dx = input_grad(dout, shape_op(x), begin, end, strides)
  240. return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
  241. return bprop
  242. @bprop_getters.register(P.Eye)
  243. def get_bprop_eye(self):
  244. """Generate bprop for Eye"""
  245. def bprop(n, m, t, out, dout):
  246. return zeros_like(n), zeros_like(m), zeros_like(t)
  247. return bprop
  248. @bprop_getters.register(P.Select)
  249. def get_bprop_select(self):
  250. """Generate bprop for Select"""
  251. select = P.Select()
  252. def bprop(cond, x, y, out, dout):
  253. return zeros_like(cond), select(cond, dout, zeros_like(x)), select(cond, zeros_like(y), dout)
  254. return bprop
  255. @bprop_getters.register(P.OnesLike)
  256. def get_bprop_oneslike(self):
  257. """Generate bprop for OnesLike"""
  258. def bprop(x, out, dout):
  259. return (zeros_like(x),)
  260. return bprop
  261. @bprop_getters.register(P.ZerosLike)
  262. def get_bprop_zeroslike(self):
  263. """Generate bprop for OnesLike"""
  264. def bprop(x, out, dout):
  265. return (zeros_like(x),)
  266. return bprop
  267. @bprop_getters.register(P.ResizeNearestNeighbor)
  268. def get_bprop_resize_nearest_neighbor(self):
  269. """Generate bprop for ResizeNearestNeighbor"""
  270. op = G.ResizeNearestNeighborGrad(self.align_corners)
  271. def bprop(inputs, out, dout):
  272. shp = shape_op(inputs)
  273. # 2 and 3 represent the height and width
  274. shp = (shp[2], shp[3])
  275. return (op(dout, shp),)
  276. return bprop
  277. @bprop_getters.register(P.GatherNd)
  278. def get_bprop_gather_nd(self):
  279. """Generate bprop for GatherNd"""
  280. op = P.ScatterNd()
  281. def bprop(x, indices, out, dout):
  282. shp = shape_op(x)
  283. return op(indices, dout, shp), zeros_like(indices)
  284. return bprop
  285. @bprop_getters.register(P.ScatterNd)
  286. def get_bprop_scatter_nd(self):
  287. """Generate bprop for ScatterNd"""
  288. op = P.GatherNd()
  289. def bprop(indices, x, shape, out, dout):
  290. return zeros_like(indices), op(dout, indices), zeros_like(shape)
  291. return bprop
  292. @bprop_getters.register(P.ScatterNdUpdate)
  293. def get_bprop_scatter_nd_update(self):
  294. """Generate bprop for ScatterNdUpdate"""
  295. op = P.GatherNd()
  296. def bprop(x, indices, update, out, dout):
  297. return dout, zeros_like(indices), op(dout, indices)
  298. return bprop
  299. @bprop_getters.register(P.Argmax)
  300. def get_bprop_argmax(self):
  301. """Generate bprop for Argmax"""
  302. def bprop(x, out, dout):
  303. return (zeros_like(x),)
  304. return bprop
  305. @bprop_getters.register(P.Argmin)
  306. def get_bprop_argmin(self):
  307. """Generate bprop for Argmin"""
  308. def bprop(x, out, dout):
  309. return (zeros_like(x),)
  310. return bprop
  311. @bprop_getters.register(P.SpaceToDepth)
  312. def get_bprop_space_to_depth(self):
  313. """Generate bprop for SpaceToDepth"""
  314. op = P.DepthToSpace(self.block_size)
  315. def bprop(x, out, dout):
  316. return (op(dout),)
  317. return bprop
  318. @bprop_getters.register(P.DepthToSpace)
  319. def get_bprop_depth_to_space(self):
  320. """Generate bprop for DepthToSpace"""
  321. op = P.SpaceToDepth(self.block_size)
  322. def bprop(x, out, dout):
  323. return (op(dout),)
  324. return bprop
  325. @bprop_getters.register(P.Diag)
  326. def get_bprop_diag(self):
  327. """Generate bprop for Diag"""
  328. op = P.DiagPart()
  329. def bprop(x, out, dout):
  330. return (op(dout),)
  331. return bprop
  332. @bprop_getters.register(P.DiagPart)
  333. def get_bprop_diag_part(self):
  334. """Generate bprop for DiagPart"""
  335. op = P.Diag()
  336. def bprop(x, out, dout):
  337. return (op(dout),)
  338. return bprop
def _GatherDropNegatives(params,
                         ids,
                         zero_clipped_indices=None,
                         is_positive=None):
    """Helper function for unsorted segment ops.

    Gathers `params` along axis 0 at `ids` and replaces the entries whose id is
    negative with zeros.

    Args:
        params: Tensor to gather from.
        ids: Segment ids. May be None when both `zero_clipped_indices` and
            `is_positive` are passed in (e.g. reused from an earlier call).
        zero_clipped_indices: `ids` with negative values raised to 0; computed when None.
        is_positive: Mask that is True where `ids >= 0` (despite the name, 0 counts);
            computed when None.

    Returns:
        tuple: (gathered values with negative-id entries zeroed,
        zero_clipped_indices, is_positive expanded to the gathered shape).
    """
    maximum = P.Maximum()
    gather = P.GatherV2()
    greater_equal = P.GreaterEqual()
    rank = P.Rank()
    fill = P.Fill()
    select = P.Select()
    if zero_clipped_indices is None:
        # Clip negative ids to 0 so the gather below stays in range.
        zero_clipped_indices = maximum(ids, zeros_like(ids))
    gathered = gather(params, zero_clipped_indices, 0)
    if is_positive is None:
        is_positive = greater_equal(ids, 0)
    is_positive_shape = shape_op(is_positive)
    broadcastable_shape = is_positive_shape
    # Append trailing 1s so the mask has the same rank as `gathered`.
    for _ in range(rank(gathered) - rank(is_positive)):
        broadcastable_shape += (1,)
    is_positive = reshape(is_positive, broadcastable_shape)
    gathered_shape = shape_op(gathered)
    # AND with an all-True tensor of gathered's shape to expand the mask to that shape.
    is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
    zero_slice = zeros_like(gathered)
    return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
  364. @bprop_getters.register(P.UnsortedSegmentMin)
  365. def get_bprop_unsorted_segment_min(self):
  366. """Generate bprop for UnsortedSegmentMin"""
  367. equal = P.Equal()
  368. cast = P.Cast()
  369. divide = P.RealDiv()
  370. get_dtype = P.DType()
  371. select = P.Select()
  372. def bprop(x, segment_ids, num_segments, out, dout):
  373. gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids)
  374. is_selected = equal(x, gathered_outputs)
  375. is_selected = logical_and(is_selected, is_positive)
  376. num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
  377. segment_ids, num_segments)
  378. weighted_grads = divide(dout, num_selected)
  379. gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
  380. zero_clipped_indices, is_positive)
  381. zeros = zeros_like(gathered_grads)
  382. return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
  383. return bprop
  384. @bprop_getters.register(P.SpaceToBatch)
  385. def get_bprop_space_to_batch(self):
  386. """Generate bprop for SpaceToBatch"""
  387. space_to_batch_grad = P.BatchToSpace(self.block_size, self.paddings)
  388. def bprop(x, out, dout):
  389. dx = space_to_batch_grad(dout)
  390. return (dx,)
  391. return bprop
  392. @bprop_getters.register(P.BatchToSpace)
  393. def get_bprop_batch_to_space(self):
  394. """Generate bprop for BatchToSpace"""
  395. batch_to_space_grad = P.SpaceToBatch(self.block_size, self.crops)
  396. def bprop(x, out, dout):
  397. dx = batch_to_space_grad(dout)
  398. return (dx,)
  399. return bprop