You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad_array_ops.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """array_ops"""
  16. from .. import operations as P
  17. from ..operations import _grad_ops as G
  18. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  19. from ..functional import broadcast_gradient_args
  20. from .. import functional as F
  21. from .grad_base import bprop_getters
  22. from ..primitive import constexpr
  23. from ... import context
  24. from ...common import dtype as mstype
  25. reduce_sum = P.ReduceSum()
  26. unsorted_segment_sum = P.UnsortedSegmentSum()
  27. transpose = P.Transpose()
  28. shape_op = P.Shape()
  29. reshape = P.Reshape()
  30. size_op = P.Size()
  31. invert_permutation = P.InvertPermutation()
  32. logical_and = P.LogicalAnd()
  33. @bprop_getters.register(P.Fill)
  34. def get_bprop_fill(self):
  35. """Generate bprop for Fill"""
  36. def bprop(dtype, dims, x, out, dout):
  37. return zeros_like(dims), zeros_like(x)
  38. return bprop
  39. @bprop_getters.register(P.DType)
  40. def get_bprop_dtype(self):
  41. """Generate bprop for DType"""
  42. def bprop(x, out, dout):
  43. return (zeros_like(x),)
  44. return bprop
  45. @bprop_getters.register(P.Cast)
  46. def get_bprop_cast(self):
  47. """Generate bprop for Cast"""
  48. cast = P.Cast()
  49. get_dtype = P.DType()
  50. def bprop(x, t, out, dout):
  51. dx = cast(dout, get_dtype(x))
  52. return dx, zeros_like(t)
  53. return bprop
  54. @bprop_getters.register(P.Shape)
  55. def get_bprop_shape(self):
  56. """Generate bprop for Shape"""
  57. def bprop(x, out, dout):
  58. return (zeros_like(x),)
  59. return bprop
  60. @bprop_getters.register(P.Split)
  61. def get_bprop_split(self):
  62. """Generate bprop for Split"""
  63. axis = self.axis
  64. def bprop(x, out, dout):
  65. concat_op = P.Concat(axis)
  66. dx = concat_op(dout)
  67. return (dx,)
  68. return bprop
  69. @bprop_getters.register(P.Rank)
  70. def get_bprop_rank(self):
  71. """Generate bprop for Rank"""
  72. def bprop(x, out, dout):
  73. return (zeros_like(x),)
  74. return bprop
  75. @bprop_getters.register(P.Reshape)
  76. def get_bprop_reshape(self):
  77. """Generate bprop for Reshape"""
  78. def bprop(x, shp, out, dout):
  79. shapex = shape_op(x)
  80. return reshape(dout, shapex), zeros_like(shp)
  81. return bprop
  82. @bprop_getters.register(P.ExpandDims)
  83. def get_bprop_expand_dims(self):
  84. """Generate bprop for ExpandDims"""
  85. def bprop(x, axis, out, dout):
  86. shapex = shape_op(x)
  87. return reshape(dout, shapex), zeros_like(axis)
  88. return bprop
  89. @bprop_getters.register(P.Squeeze)
  90. def get_bprop_squeeze(self):
  91. """Generate bprop for Squeeze"""
  92. def bprop(x, out, dout):
  93. shapex = shape_op(x)
  94. return (reshape(dout, shapex),)
  95. return bprop
  96. @bprop_getters.register(P.Flatten)
  97. def get_bprop_flatten(self):
  98. """Generate bprop for Flatten"""
  99. flatten_grad = G.FlattenGrad()
  100. def bprop(x, out, dout):
  101. dx = flatten_grad(dout, shape_op(x))
  102. return (dx,)
  103. return bprop
  104. @constexpr
  105. def _tile_shape(multiples, shapex):
  106. """Calculate [1,2], [3, 4] -> [1,3,2,4]."""
  107. len_muli = len(multiples)
  108. rank = len(shapex)
  109. len_cmp = len_muli - rank
  110. max_len = max(len_muli, rank)
  111. i = 0
  112. j = 0
  113. ret = []
  114. while (i < max_len) and (j < max_len):
  115. if len_cmp == 0:
  116. ret.append(multiples[i])
  117. ret.append(shapex[j])
  118. i += 1
  119. j += 1
  120. elif len_cmp > 0:
  121. ret.append(multiples[i])
  122. ret.append(1)
  123. i += 1
  124. len_cmp -= 1
  125. else:
  126. ret.append(1)
  127. ret.append(shapex[j])
  128. len_cmp += 1
  129. return tuple(ret)
  130. @bprop_getters.register(P.Tile)
  131. def get_bprop_tile(self):
  132. """Generate bprop for Tile"""
  133. def bprop(x, multiples, out, dout):
  134. shapex = shape_op(x)
  135. r_shape = _tile_shape(multiples, shapex)
  136. # 0 represents the start index, and 2 represents the step
  137. axis = F.make_range(0, len(r_shape), 2)
  138. dx = reduce_sum(reshape(dout, r_shape), axis)
  139. dx = reshape(dx, shapex)
  140. return dx, zeros_like(multiples)
  141. return bprop
  142. @bprop_getters.register(P.Transpose)
  143. def get_bprop_transpose(self):
  144. """Generate bprop for Transpose"""
  145. def bprop(x, perm, out, dout):
  146. return transpose(dout, invert_permutation(perm)), zeros_like(perm)
  147. return bprop
  148. @bprop_getters.register(P.Concat)
  149. def get_bprop_concat(self):
  150. """Generate bprop for Concat"""
  151. axis = self.axis
  152. def bprop(x, out, dout):
  153. dx = ()
  154. out_offset = G.ConcatOffset(F.tuple_len(x), axis)(x)
  155. for i in range(F.tuple_len(x)):
  156. slice_out = P.Slice()(dout, out_offset[i], shape_op(x[i]))
  157. dx = dx + (slice_out,)
  158. return (dx,)
  159. return bprop
  160. @constexpr
  161. def _slice_grad_pad(begins, sizes, shapes):
  162. pads = tuple((begin, shape - begin - size) for begin, size, shape in zip(begins, sizes, shapes))
  163. return pads
  164. @bprop_getters.register(P.Slice)
  165. def get_bprop_slice(self):
  166. """Generate bprop for Slice"""
  167. def bprop(x, begin, size, out, dout):
  168. dx = P.Pad(_slice_grad_pad(begin, size, shape_op(x)))(dout)
  169. return (dx, zeros_like(begin), zeros_like(size))
  170. def bprop_grad(x, begin, size, out, dout):
  171. dx = dx = G.SliceGrad()(dout, x, begin, size)
  172. return (dx, zeros_like(begin), zeros_like(size))
  173. if context.get_context('device_target') == "GPU" or context.get_context('device_target') == "CPU":
  174. return bprop_grad
  175. return bprop
  176. @constexpr
  177. def _generate_shape_index(out_shape, indices_shape, axis):
  178. out_rank = len(out_shape)
  179. ind_rank = len(indices_shape)
  180. if axis < 0:
  181. axis += out_rank - ind_rank + 1
  182. perm_part1 = tuple(range(axis, axis + ind_rank))
  183. index = tuple(range(out_rank))
  184. perm = perm_part1 + index[:axis] + index[axis + ind_rank:]
  185. return perm
  186. @constexpr
  187. def _generate_inverse_index(x_shape, axis):
  188. x_rank = len(x_shape)
  189. index = tuple(range(x_rank))
  190. if axis < 0:
  191. axis += x_rank
  192. perm = index[1:1 + axis] + (0,) + index[1 + axis:]
  193. return perm
  194. @bprop_getters.register(P.GatherV2)
  195. def get_bprop_gather_v2(self):
  196. """Generate bprop for GatherV2"""
  197. def bprop(x, indices, axis, out, dout):
  198. if F.rank(dout) == 0:
  199. dout = P.ExpandDims()(dout, -1)
  200. if F.rank(indices) == 0:
  201. indices = P.ExpandDims()(indices, -1)
  202. x_shp = shape_op(x)
  203. out_shp = shape_op(dout)
  204. ind_shp = shape_op(indices)
  205. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  206. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  207. values_transpose = transpose(dout, perm_1)
  208. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  209. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  210. perm_2 = _generate_inverse_index(x_shp, axis)
  211. params_grad = transpose(params_grad, perm_2)
  212. return params_grad, zeros_like(indices), zeros_like(axis)
  213. return bprop
  214. @bprop_getters.register(P.SparseGatherV2)
  215. def get_bprop_sparse_gather_v2(self):
  216. """Generate bprop for SparseGatherV2"""
  217. def bprop(x, indices, axis, out, dout):
  218. x_shp = shape_op(x)
  219. if axis == 0:
  220. indices_size = (size_op(indices),)
  221. x_tail_shp = x_shp[1:]
  222. values_shape = indices_size + x_tail_shp
  223. values = reshape(dout, values_shape)
  224. indices = reshape(indices, indices_size)
  225. return (indices, values, x_shp), zeros_like(indices), zeros_like(axis)
  226. if F.rank(dout) == 0:
  227. dout = P.ExpandDims()(dout, -1)
  228. if F.rank(indices) == 0:
  229. indices = P.ExpandDims()(indices, -1)
  230. out_shp = shape_op(dout)
  231. ind_shp = shape_op(indices)
  232. # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
  233. perm_1 = _generate_shape_index(out_shp, ind_shp, axis)
  234. values_transpose = transpose(dout, perm_1)
  235. params_grad = unsorted_segment_sum(values_transpose, indices, shape_op(x)[axis])
  236. # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
  237. perm_2 = _generate_inverse_index(x_shp, axis)
  238. params_grad = transpose(params_grad, perm_2)
  239. return params_grad, zeros_like(indices), zeros_like(axis)
  240. return bprop
  241. @bprop_getters.register(P.Range)
  242. def get_bprop_range(self):
  243. """Generate bprop for Range"""
  244. def bprop(x, out, dout):
  245. return (zeros_like(x),)
  246. return bprop
  247. @bprop_getters.register(P.Pack)
  248. def get_bprop_pack(self):
  249. """Generate bprop for Pack"""
  250. axis = self.axis
  251. def bprop(x, out, dout):
  252. pack_grad = P.Unpack(axis)
  253. out = pack_grad(dout)
  254. return (out,)
  255. return bprop
  256. @bprop_getters.register(P.Unpack)
  257. def get_bprop_unpack(self):
  258. """Generate bprop for Unpack"""
  259. axis = self.axis
  260. def bprop(x, out, dout):
  261. unpack_grad = P.Pack(axis)
  262. out = unpack_grad(dout)
  263. return (out,)
  264. return bprop
  265. @bprop_getters.register(P.StridedSlice)
  266. def get_bprop_strided_slice(self):
  267. """Generate bprop for StridedSlice"""
  268. input_grad = G.StridedSliceGrad(self.begin_mask,
  269. self.end_mask,
  270. self.ellipsis_mask,
  271. self.new_axis_mask,
  272. self.shrink_axis_mask)
  273. def bprop(x, begin, end, strides, out, dout):
  274. dx = input_grad(dout, shape_op(x), begin, end, strides)
  275. return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
  276. return bprop
  277. @bprop_getters.register(P.Eye)
  278. def get_bprop_eye(self):
  279. """Generate bprop for Eye"""
  280. def bprop(n, m, t, out, dout):
  281. return zeros_like(n), zeros_like(m), zeros_like(t)
  282. return bprop
  283. @bprop_getters.register(P.Select)
  284. def get_bprop_select(self):
  285. """Generate bprop for Select"""
  286. select = P.Select()
  287. def bprop(cond, x, y, out, dout):
  288. return zeros_like(cond), select(cond, dout, zeros_like(x)), select(cond, zeros_like(y), dout)
  289. return bprop
  290. @bprop_getters.register(P.OnesLike)
  291. def get_bprop_oneslike(self):
  292. """Generate bprop for OnesLike"""
  293. def bprop(x, out, dout):
  294. return (zeros_like(x),)
  295. return bprop
  296. @bprop_getters.register(P.ZerosLike)
  297. def get_bprop_zeroslike(self):
  298. """Generate bprop for OnesLike"""
  299. def bprop(x, out, dout):
  300. return (zeros_like(x),)
  301. return bprop
  302. @bprop_getters.register(P.ResizeNearestNeighbor)
  303. def get_bprop_resize_nearest_neighbor(self):
  304. """Generate bprop for ResizeNearestNeighbor"""
  305. op = G.ResizeNearestNeighborGrad(self.align_corners)
  306. def bprop(inputs, out, dout):
  307. shp = shape_op(inputs)
  308. # 2 and 3 represent the height and width
  309. shp = (shp[2], shp[3])
  310. return (op(dout, shp),)
  311. return bprop
  312. @bprop_getters.register(P.GatherNd)
  313. def get_bprop_gather_nd(self):
  314. """Generate bprop for GatherNd"""
  315. op = P.ScatterNd()
  316. def bprop(x, indices, out, dout):
  317. shp = shape_op(x)
  318. return op(indices, dout, shp), zeros_like(indices)
  319. return bprop
  320. @bprop_getters.register(P.ScatterNd)
  321. def get_bprop_scatter_nd(self):
  322. """Generate bprop for ScatterNd"""
  323. op = P.GatherNd()
  324. def bprop(indices, x, shape, out, dout):
  325. return zeros_like(indices), op(dout, indices), zeros_like(shape)
  326. return bprop
  327. @bprop_getters.register(P.ScatterNdUpdate)
  328. def get_bprop_scatter_nd_update(self):
  329. """Generate bprop for ScatterNdUpdate"""
  330. op = P.GatherNd()
  331. def bprop(x, indices, update, out, dout):
  332. return dout, zeros_like(indices), op(dout, indices)
  333. return bprop
  334. @bprop_getters.register(P.Argmax)
  335. def get_bprop_argmax(self):
  336. """Generate bprop for Argmax"""
  337. def bprop(x, out, dout):
  338. return (zeros_like(x),)
  339. return bprop
  340. @bprop_getters.register(P.Argmin)
  341. def get_bprop_argmin(self):
  342. """Generate bprop for Argmin"""
  343. def bprop(x, out, dout):
  344. return (zeros_like(x),)
  345. return bprop
  346. @bprop_getters.register(P.SpaceToDepth)
  347. def get_bprop_space_to_depth(self):
  348. """Generate bprop for SpaceToDepth"""
  349. op = P.DepthToSpace(self.block_size)
  350. def bprop(x, out, dout):
  351. return (op(dout),)
  352. return bprop
  353. @bprop_getters.register(P.DepthToSpace)
  354. def get_bprop_depth_to_space(self):
  355. """Generate bprop for DepthToSpace"""
  356. op = P.SpaceToDepth(self.block_size)
  357. def bprop(x, out, dout):
  358. return (op(dout),)
  359. return bprop
  360. @bprop_getters.register(P.Diag)
  361. def get_bprop_diag(self):
  362. """Generate bprop for Diag"""
  363. op = P.DiagPart()
  364. def bprop(x, out, dout):
  365. return (op(dout),)
  366. return bprop
  367. @bprop_getters.register(P.DiagPart)
  368. def get_bprop_diag_part(self):
  369. """Generate bprop for DiagPart"""
  370. op = P.Diag()
  371. def bprop(x, out, dout):
  372. return (op(dout),)
  373. return bprop
  374. def _GatherDropNegatives(params,
  375. ids,
  376. zero_clipped_indices=None,
  377. is_positive=None):
  378. """Helper function for unsorted segment ops."""
  379. maximum = P.Maximum()
  380. gather = P.GatherV2()
  381. greater_equal = P.GreaterEqual()
  382. rank = P.Rank()
  383. fill = P.Fill()
  384. select = P.Select()
  385. if zero_clipped_indices is None:
  386. zero_clipped_indices = maximum(ids, zeros_like(ids))
  387. gathered = gather(params, zero_clipped_indices, 0)
  388. if is_positive is None:
  389. is_positive = greater_equal(ids, 0)
  390. is_positive_shape = shape_op(is_positive)
  391. broadcastable_shape = is_positive_shape
  392. for _ in range(rank(gathered) - rank(is_positive)):
  393. broadcastable_shape += (1,)
  394. is_positive = reshape(is_positive, broadcastable_shape)
  395. gathered_shape = shape_op(gathered)
  396. is_positive = logical_and(is_positive, fill(mstype.bool_, gathered_shape, 1))
  397. zero_slice = zeros_like(gathered)
  398. return (select(is_positive, gathered, zero_slice), zero_clipped_indices, is_positive)
  399. @bprop_getters.register(P.UnsortedSegmentMin)
  400. def get_bprop_unsorted_segment_min(self):
  401. """Generate bprop for UnsortedSegmentMin"""
  402. equal = P.Equal()
  403. cast = P.Cast()
  404. divide = P.RealDiv()
  405. get_dtype = P.DType()
  406. select = P.Select()
  407. def bprop(x, segment_ids, num_segments, out, dout):
  408. gathered_outputs, zero_clipped_indices, is_positive = _GatherDropNegatives(out, segment_ids)
  409. is_selected = equal(x, gathered_outputs)
  410. is_selected = logical_and(is_selected, is_positive)
  411. num_selected = unsorted_segment_sum(cast(is_selected, get_dtype(dout)),
  412. segment_ids, num_segments)
  413. weighted_grads = divide(dout, num_selected)
  414. gathered_grads, _, _ = _GatherDropNegatives(weighted_grads, None,
  415. zero_clipped_indices, is_positive)
  416. zeros = zeros_like(gathered_grads)
  417. return select(is_selected, gathered_grads, zeros), zeros_like(segment_ids), zeros_like(num_segments)
  418. return bprop
  419. @bprop_getters.register(P.SpaceToBatch)
  420. def get_bprop_space_to_batch(self):
  421. """Generate bprop for SpaceToBatch"""
  422. space_to_batch_grad = P.BatchToSpace(self.block_size, self.paddings)
  423. def bprop(x, out, dout):
  424. dx = space_to_batch_grad(dout)
  425. return (dx,)
  426. return bprop
  427. @bprop_getters.register(P.BatchToSpace)
  428. def get_bprop_batch_to_space(self):
  429. """Generate bprop for BatchToSpace"""
  430. batch_to_space_grad = P.SpaceToBatch(self.block_size, self.crops)
  431. def bprop(x, out, dout):
  432. dx = batch_to_space_grad(dout)
  433. return (dx,)
  434. return bprop
  435. @bprop_getters.register(P.SpaceToBatchND)
  436. def get_bprop_space_to_batch_nd(self):
  437. """Generate bprop for SpaceToBatchND"""
  438. space_to_batch_nd_grad = P.BatchToSpaceND(self.block_shape, self.paddings)
  439. def bprop(x, out, dout):
  440. dx = space_to_batch_nd_grad(dout)
  441. return (dx,)
  442. return bprop
  443. @bprop_getters.register(P.BatchToSpaceND)
  444. def get_bprop_batch_to_space_nd(self):
  445. """Generate bprop for BatchToSpaceND"""
  446. batch_to_space_nd_grad = P.SpaceToBatchND(self.block_shape, self.crops)
  447. def bprop(x, out, dout):
  448. dx = batch_to_space_nd_grad(dout)
  449. return (dx,)
  450. return bprop
  451. @bprop_getters.register(P.BroadcastTo)
  452. def get_bprop_broadcast_to(self):
  453. """Generate bprop for BroadcastTo"""
  454. reduce_keep_dim = P.ReduceSum(keep_dims=True)
  455. broadcast_shape = self.shape
  456. def bprop(x, out, dout):
  457. x_shape = shape_op(x)
  458. _, reduction_axes = broadcast_gradient_args(broadcast_shape, x_shape)
  459. reduced_grad = reduce_keep_dim(dout, reduction_axes)
  460. dx = reshape(reduced_grad, x_shape)
  461. return (dx,)
  462. return bprop