# grad_nn_ops.py
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define the grad rules of neural network related operations."""
from mindspore.ops.primitive import constexpr
from mindspore.ops.operations import nn_ops as nps
from .grad_base import bprop_getters
from .. import functional as F
from .. import operations as P
from ...common import dtype as mstype
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ... import context
from .._utils.utils import range_op, get_1d_shape


@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
    """Grad definition for `BiasAdd` operation."""
    bias_grad = G.BiasAddGrad(self.data_format)

    def bprop(x, w, out, dout):
        return dout, bias_grad(dout)

    return bprop


@bprop_getters.register(P.Conv2D)
def get_bprop_conv2d(self):
    """Grad definition for `Conv2D` operation."""
    self.out_channel = self.get_attr_dict()["out_channel"]
    self.pad_list = self.get_attr_dict()["pad_list"]
    input_grad = P.Conv2DBackpropInput(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    filter_grad = G.Conv2DBackpropFilter(
        self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    get_shape = P.Shape()
    get_dyn_shape = P.DynamicShape()

    def bprop(x, w, out, dout):
        x_shape = get_shape(x)
        w_shape = get_shape(w)
        if -1 in x_shape:
            x_shape = get_dyn_shape(x)
        if -1 in w_shape:
            w_shape = get_dyn_shape(w)
        dx = input_grad(dout, w, x_shape)
        dw = filter_grad(dout, x, w_shape)
        return dx, dw

    return bprop


@bprop_getters.register(nps.Conv3D)
def get_bprop_conv3d(self):
    """Grad definition for `Conv3D` operation."""
    input_grad = nps.Conv3DBackpropInput(
        self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    filter_grad = G.Conv3DBackpropFilter(
        self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
        pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )
    get_shape = P.Shape()

    def bprop(x, w, out, dout):
        dx = input_grad(w, dout, get_shape(x))
        dw = filter_grad(x, dout, get_shape(w))
        return dx, dw

    return bprop


@bprop_getters.register(nps.Conv3DTranspose)
def get_bprop_conv3d_transpose(self):
    """Grad definition for `Conv3DTranspose` operation."""
    stride = (self.stride[2], self.stride[3], self.stride[4])
    dilation = (self.dilation[2], self.dilation[3], self.dilation[4])
    input_grad = nps.Conv3D(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=self.pad_list, stride=stride, dilation=dilation, group=self.group, data_format=self.data_format
    )
    filter_grad = G.Conv3DBackpropFilter(
        out_channel=self.in_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode="pad",
        pad=self.pad_list, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
    )

    def bprop(x, w, out, dout):
        dx = input_grad(dout, w)
        dw = filter_grad(dout, x, F.shape(w))
        return dx, dw

    return bprop


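# The ExtractImagePatches bprop routes every patch element's gradient back to the input
# pixel it was copied from: an index image is pushed through the forward op to record the
# pixel-to-patch mapping, that mapping is turned into a 0/1 scatter matrix with ScatterNd,
# and a MatMul accumulates the patch gradients into the original input layout.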
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
    """Grad definition for `ExtractImagePatches` operation."""
    get_shape = P.Shape()
    reshape = P.Reshape()
    extract_image_patches = inner.ExtractImagePatches(ksizes=self.ksizes,
                                                      strides=self.strides,
                                                      rates=self.rates,
                                                      padding=self.padding)
    concat = P.Concat(axis=-1)
    expand_dims = P.ExpandDims()
    scatter_nd = P.ScatterNd()
    dtype = P.DType()
    fill = P.Fill()
    slice_op = P.Slice()
    transpose = P.Transpose()
    cast = P.Cast()
    matmul = P.MatMul()
    _, _, ksizes_row, ksizes_col = self.ksizes

    def bprop(x, out, dout):
        x_shape = get_shape(x)
        x_batch, x_depth, x_row, x_col = x_shape
        x_indices_num = x_row * x_col + 1
        x_idx = cast(F.tuple_to_array(range(1, x_indices_num)), mstype.float32)
        x_idx = reshape(x_idx, (1, 1, x_row, x_col))
        x_idx_patch = cast(extract_image_patches(x_idx), mstype.int32)
        x_idx_patch = transpose(x_idx_patch, (0, 2, 3, 1))
        out_shape = get_shape(out)
        _, _, out_row, out_col = out_shape
        out_indices_num = out_row * out_col * ksizes_row * ksizes_col
        out_idx = F.tuple_to_array(range(out_indices_num))
        out_idx = reshape(out_idx, (1, out_row, out_col, ksizes_row * ksizes_col))
        idx_tensor = concat((expand_dims(x_idx_patch, -1), expand_dims(out_idx, -1)))
        idx_tensor = reshape(idx_tensor, (-1, 2))
        sp_shape = (x_indices_num, out_indices_num)
        sp_tensor = scatter_nd(idx_tensor, fill(dtype(dout), (out_indices_num,), 1), sp_shape)
        sp_tensor = slice_op(sp_tensor, (1, 0), (x_indices_num - 1, out_indices_num))
        grad = transpose(dout, (0, 2, 3, 1))
        grad = reshape(grad, (x_batch, out_row, out_col, ksizes_row, ksizes_col, x_depth))
        grad = transpose(grad, (1, 2, 3, 4, 0, 5))
        grad = reshape(grad, (-1, x_batch * x_depth))
        jac = matmul(sp_tensor, grad)
        dx = reshape(jac, (x_row, x_col, x_batch, x_depth))
        dx = transpose(dx, (2, 3, 0, 1))
        return (dx,)

    return bprop


@bprop_getters.register(P.DepthwiseConv2dNative)
def get_bprop_depthwise_conv2d_native(self):
    """Grad definition for `DepthwiseConv2dNative` operation."""
    input_grad = G.DepthwiseConv2dNativeBackpropInput(
        self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
        self.dilation, self.group
    )
    filter_grad = G.DepthwiseConv2dNativeBackpropFilter(
        self.channel_multiplier, self.kernel_size, self.pad_mode, self.pad, self.pad_list, self.mode, self.stride,
        self.dilation, self.group
    )
    get_shape = P.Shape()

    def bprop(x, w, out, dout):
        dx = input_grad(get_shape(x), w, dout)
        dw = filter_grad(x, get_shape(w), dout)
        return dx, dw

    return bprop


@bprop_getters.register(P.MaxPoolWithArgmax)
def get_bprop_max_pool_with_argmax(self):
    """Grad definition for `MaxPoolWithArgmax` operation."""
    maxpool_grad = G.MaxPoolGradWithArgmax(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode)

    def bprop(x, out, dout):
        dx = maxpool_grad(x, dout[0], out[1])
        return (dx,)

    return bprop


@bprop_getters.register(G.MaxPoolGrad)
def get_bprop_max_pool_grad_grad(self):
    """Grad definition for `MaxPoolGrad` operation."""
    maxpool_grad_grad = G.MaxPoolGradGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode)

    def bprop(x1, x2, grad, out, dout):
        dx1 = zeros_like(x1)
        dx2 = zeros_like(x2)
        dgrad = maxpool_grad_grad(x1, x2, dout)
        return (dx1, dx2, dgrad)

    return bprop


@bprop_getters.register(G.MaxPoolGradGrad)
def get_bprop_max_pool_grad_grad_grad(self):
    """Grad definition for `MaxPoolGradGrad` operation."""
    maxpool_grad = G.MaxPoolGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode)

    def bprop(x1, x2, grad, out, dout):
        dx1 = zeros_like(x1)
        dx2 = zeros_like(x2)
        dgrad = maxpool_grad(x1, x2, dout)
        return (dx1, dx2, dgrad)

    return bprop


@bprop_getters.register(P.MaxPool)
def get_bprop_max_pool_grad(self):
    """Grad definition for `MaxPool` operation."""
    maxpool_grad = G.MaxPoolGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.format)

    def bprop(x, out, dout):
        dx = maxpool_grad(x, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.MaxPool3D)
def get_bprop_max_pool3d_grad(self):
    """Grad definition for `MaxPool3D` operation."""
    max_pool3d_grad = G.MaxPool3DGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        pad_list=self.pad_list,
        data_format=self.data_format)

    def bprop(x, out, dout):
        dx = max_pool3d_grad(x, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(G.MaxPool3DGrad)
def get_bprop_max_pool3d_grad_grad(self):
    """Grad definition for `MaxPool3DGrad` operation."""
    max_pool3d_grad_grad = G.MaxPool3DGradGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        dgrad = max_pool3d_grad_grad(x, y, dout)
        return zeros_like(x), zeros_like(y), dgrad

    return bprop


@bprop_getters.register(G.MaxPool3DGradGrad)
def get_bprop_max_pool3d_grad_grad_grad(self):
    """Grad definition for `MaxPool3DGradGrad` operation."""
    max_pool3d_grad = G.MaxPool3DGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.data_format)

    def bprop(x, y, grad, out, dout):
        dgrad = max_pool3d_grad(x, y, dout)
        return zeros_like(x), zeros_like(y), dgrad

    return bprop


@bprop_getters.register(P.AvgPool)
def get_bprop_avg_pool_grad(self):
    """Grad definition for `AvgPool` operation."""
    avgpool_grad = G.AvgPoolGrad(
        kernel_size=self.kernel_size,
        strides=self.strides,
        pad_mode=self.pad_mode,
        data_format=self.format)

    def bprop(x, out, dout):
        dx = avgpool_grad(x, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.AdaptiveAvgPool2D)
def get_bprop_adaptive_avg_pool2d_grad(self):
    """Grad definition for `AdaptiveAvgPool2D` operation."""
    adaptive_avgpool_grad = G.AdaptiveAvgPool2DGrad()

    def bprop(x, out, dout):
        dx = adaptive_avgpool_grad(x, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.AvgPool3D)
def get_bprop_avg_pool_3d_grad(self):
    """Grad definition for `AvgPool3D` operation."""
    pad_list = self.get_attr_dict()['pad_list']
    count_include_pad = self.get_attr_dict()['count_include_pad']
    avgpool3d_grad = G.AvgPool3DGrad(kernel_size=self.kernel_size,
                                     strides=self.strides,
                                     pads=pad_list,
                                     ceil_mode=self.ceil_mode,
                                     count_include_pad=count_include_pad,
                                     divisor_override=self.divisor_override,
                                     data_format=self.data_format)

    def bprop(x, out, dout):
        x_shape = F.shape(x)
        dx = avgpool3d_grad(x_shape, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.DropoutGenMask)
def get_bprop_dropout_gen_mask(self):
    """Grad definition for `DropoutGenMask` operation."""

    def bprop(shape, keep_prob, out, dout):
        return (zeros_like(shape), zeros_like(keep_prob))

    return bprop


@bprop_getters.register(P.DropoutDoMask)
def get_bprop_dropout_do_mask(self):
    """Grad definition for `DropoutDoMask` operation."""
    do_mask = P.DropoutDoMask()

    def bprop(x, y, keep_prob, out, dout):
        return (do_mask(dout, y, keep_prob), zeros_like(y), zeros_like(keep_prob))

    return bprop


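# Mish(x) = x * tanh(softplus(x)). The bprop below applies the product and chain rules:
# the tanh(softplus(x)) term is dx1 * dout, and the x * sech^2(softplus(x)) * sigmoid(x)
# term is assembled from TanhGrad followed by SoftplusGrad.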
@bprop_getters.register(P.Mish)
def get_bprop_mish(self):
    """Grad definition for `Mish` operation."""
    tanh = P.Tanh()
    tanh_grad = G.TanhGrad()
    softplus = P.Softplus()
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
        dx1 = tanh(softplus(x))
        dx2 = softplus_grad(tanh_grad(dx1, x * dout), x)
        dx = (dx1 * dout + dx2)
        return (dx,)

    return bprop


@bprop_getters.register(P.SeLU)
def get_bprop_selu(self):
    """Grad definition for `SeLU` operation."""
    scale = 1.0507009873554804934193349852946
    elu_grad = G.EluGrad()

    def bprop(x, out, dout):
        dx = elu_grad(dout, out) * scale
        return (dx,)

    return bprop


@bprop_getters.register(P.MulNoNan)
def get_bprop_mul_no_nan(self):
    """Grad definition for `MulNoNan` operation."""
    mul_no_nan = P.MulNoNan()
    reduce_sum = P.ReduceSum()
    reshape = P.Reshape()

    def bprop(x, y, out, dout):
        x_shape = F.shape(x)
        y_shape = F.shape(y)
        dx = mul_no_nan(dout, y)
        dy = mul_no_nan(x, dout)
        broadcast_x, broadcast_y = F.broadcast_gradient_args(x_shape, y_shape)
        if broadcast_x != ():
            dx = reshape(reduce_sum(dx, broadcast_x), x_shape)
        if broadcast_y != ():
            dy = reshape(reduce_sum(dy, broadcast_y), y_shape)
        return dx, dy

    return bprop


@bprop_getters.register(P.ReLU)
def get_bprop_relu(self):
    """Grad definition for `ReLU` operation."""
    input_grad = G.ReluGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, out)
        return (dx,)

    return bprop


@bprop_getters.register(G.ReluGrad)
def get_bprop_relu_grad(self):
    """Grad definition for `ReluGrad` operation."""
    input_grad = G.ReluGrad()

    def bprop(grad, y, out, dout):
        dgrad = input_grad(dout, y)
        return dgrad, zeros_like(y)

    return bprop


@bprop_getters.register(P.ReLU6)
def get_bprop_relu6(self):
    """Grad definition for `ReLU6` operation."""
    input_grad = G.ReLU6Grad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.ReLUV2)
def get_bprop_relu_v2(self):
    """Grad definition for `ReLUV2` operation."""
    input_grad = G.ReluGradV2()

    def bprop(x, out, dout):
        mask = out[1]
        dx = input_grad(dout[0], mask)
        return (dx,)

    return bprop


@bprop_getters.register(P.HSwish)
def get_bprop_hswish(self):
    """Grad definition for `HSwish` operation."""
    input_grad = G.HSwishGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.HSigmoid)
def get_bprop_hsigmoid(self):
    """Grad definition for `HSigmoid` operation."""
    input_grad = G.HSigmoidGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.Elu)
def get_bprop_elu(self):
    """Grad definition for `Elu` operation."""
    input_grad = G.EluGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, out)
        return (dx,)

    return bprop


@bprop_getters.register(P.Sigmoid)
def get_bprop_sigmoid(self):
    """Grad definition for `Sigmoid` operation."""
    input_grad = G.SigmoidGrad()

    def bprop(x, out, dout):
        dx = input_grad(out, dout)
        return (dx,)

    return bprop


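# SigmoidGrad(y, dy) computes dy * y * (1 - y). Differentiating that product again gives
# the second-order rule used below: the d/dy part contributes dout * grad * (1 - 2 * y),
# and the d/dgrad part is just another SigmoidGrad evaluated on the incoming dout.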
@bprop_getters.register(G.SigmoidGrad)
def get_bprop_sigmoid_grad(self):
    """Grad definition for `SigmoidGrad` operation."""
    sigmoid_grad = G.SigmoidGrad()

    def bprop(y, grad, out, dout):
        dy = dout * grad * (1. - 2 * y)
        dgrad = sigmoid_grad(y, dout)
        return dy, dgrad

    return bprop


@constexpr
def _get_transpose_axis(x_shp, axis):
    rank = len(x_shp)
    if axis < 0:
        axis += rank
    reverse_axis = [i for i in range(rank)]
    reverse_axis[axis] = rank - 1
    reverse_axis[rank - 1] = axis
    return tuple(reverse_axis)


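# _get_transpose_axis builds the permutation that swaps `axis` with the last dimension,
# e.g. for a 4-D input and axis=1 it returns (0, 3, 2, 1). Applying the same permutation
# twice restores the original layout, which is why the Softmax bprop below transposes
# with `reverse_axis` both before and after the reduction.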
@bprop_getters.register(P.Softmax)
def get_bprop_softmax(self):
    """Grad definition for `Softmax` operation."""
    sum_func = P.ReduceSum(keep_dims=True)
    sub = P.Sub()
    mul = P.Mul()
    get_shape = P.Shape()
    transpose = P.Transpose()
    axis = self.axis
    if not isinstance(axis, int):
        axis = axis[0]

    def bprop(x, out, dout):
        # dx = (dout - sum(dout * out)) * out
        # This formula is correct only when `axis` is the last dimension.
        # To support other values of `axis`, transpose the `axis` dimension to the last
        # dimension, compute the gradient there, and then transpose it back.
        reverse_axis = _get_transpose_axis(get_shape(x), axis)
        out = transpose(out, reverse_axis)
        dout = transpose(dout, reverse_axis)
        dx = mul(out, sub(dout, sum_func(mul(out, dout), -1)))
        dx = transpose(dx, reverse_axis)
        return (dx,)

    return bprop


@bprop_getters.register(P.LogSoftmax)
def get_bprop_log_softmax(self):
    """Grad definition for `LogSoftmax` operation."""
    logsoftmax_grad = G.LogSoftmaxGrad(self.axis)

    def bprop(x, out, dout):
        dx = logsoftmax_grad(out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.Softplus)
def get_bprop_softplus(self):
    """Grad definition for `Softplus` operation."""
    softplus_grad = G.SoftplusGrad()

    def bprop(x, out, dout):
        dx = softplus_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.Softsign)
def get_bprop_softsign(self):
    """Grad definition for `Softsign` operation."""
    mul = P.Mul()
    absolute = P.Abs()
    div = P.Div()
    square = P.Square()

    def bprop(x, out, dout):
        # Softsign(x) = x / (1 + |x|), so d/dx = 1 / (1 + |x|)^2.
        dx = mul(dout, div(1, square(1 + absolute(x))))
        return (dx,)

    return bprop


@bprop_getters.register(P.Tanh)
def get_bprop_tanh(self):
    """Grad definition for `Tanh` operation."""
    tanh_grad = G.TanhGrad()

    def bprop(x, out, dout):
        dx = tanh_grad(out, dout)
        return (dx,)

    return bprop


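# TanhGrad(y, dy) computes dy * (1 - y * y). Differentiating again with respect to y gives
# the -2 * y * grad factor below, while the d/dgrad part is another TanhGrad on dout.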
@bprop_getters.register(G.TanhGrad)
def get_bprop_tanh_grad(self):
    """Grad definition for `TanhGrad` operation."""
    tanh_grad = G.TanhGrad()

    def bprop(y, grad, out, dout):
        dy = dout * -2.0 * grad * y
        dgrad = tanh_grad(y, dout)
        return dy, dgrad

    return bprop


@bprop_getters.register(P.Gelu)
@bprop_getters.register(P.GeLU)
def get_bprop_gelu(self):
    """Grad definition for `GeLU` operation."""
    input_grad = G.GeLUGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x, out)
        return (dx,)

    return bprop


@bprop_getters.register(P.FastGeLU)
def get_bprop_fast_gelu(self):
    """Grad definition for `FastGeLU` operation."""
    input_grad = G.FastGeLUGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.FastGelu)
def get_bprop_fast_gelu_2(self):
    """Grad definition for `FastGeLU` operation."""
    input_grad = G.FastGeLUGrad()

    def bprop(x, out, dout):
        dx = input_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.InstanceNorm)
def get_bprop_instance_norm(self):
    """Grad definition for `InstanceNorm` operation."""
    input_grad = G.InstanceNormGrad(self.epsilon, self.momentum)

    def bprop(x, gamma, beta, mean, variance, out, dout):
        saved_mean = out[1]
        saved_variance = out[2]
        out = input_grad(dout[0], x, gamma, saved_mean, saved_variance)
        dx = out[0]
        dgamma = out[1]
        dbeta = out[2]
        return dx, dgamma, dbeta, zeros_like(mean), zeros_like(variance)

    return bprop


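# BatchNormGrad needs the statistics that were actually used in the forward pass:
# in training mode those are the saved batch mean/variance returned by BatchNorm,
# while in inference mode the running `mean` and `variance` inputs are used instead.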
@bprop_getters.register(P.BatchNorm)
def get_bprop_batch_norm(self):
    """Grad definition for `BatchNorm` operation."""
    is_training = self.is_training
    input_grad = G.BatchNormGrad(is_training, self.epsilon, self.data_format)

    def bprop(x, scale, b, mean, variance, out, dout):
        if is_training:
            saved_mean = out[3]
            saved_variance = out[4]
            reserve = out[2]
        else:
            saved_mean = mean
            saved_variance = variance
            reserve = out[2]
        out = input_grad(dout[0], x, scale, saved_mean, saved_variance, reserve)
        dx = out[0]
        dscale = out[1]
        dbias = out[2]
        return dx, dscale, dbias, zeros_like(mean), zeros_like(variance)

    return bprop


@bprop_getters.register(P.LayerNorm)
def get_bprop_layer_norm(self):
    """Grad definition for `LayerNorm` operation."""
    layer_norm_grad = G.LayerNormGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, gamma, beta, out, dout):
        dx, d_gamma, d_beta = layer_norm_grad(
            x, dout[0], out[2], out[1], gamma)
        return dx, d_gamma, d_beta

    return bprop


@bprop_getters.register(G.LayerNormGrad)
def get_bprop_layer_norm_grad(self):
    """Grad definition for `LayerNormGrad` operation."""
    layer_norm_grad_grad = G.LayerNormGradGrad(self.begin_norm_axis, self.begin_params_axis)

    def bprop(x, dy, variance, mean, gamma, out, dout):
        d_x, d_dy, d_gamma = layer_norm_grad_grad(
            x, dy, variance, mean, gamma, dout[0], dout[1], dout[2])
        return d_x, d_dy, zeros_like(variance), zeros_like(mean), d_gamma

    return bprop


@bprop_getters.register(P.L2Normalize)
def get_bprop_l2normalize(self):
    """Grad definition for `L2Normalize` operation."""
    input_grad = G.L2NormalizeGrad(self.axis, self.epsilon)

    def bprop(x, out, dout):
        dx = input_grad(x, out, dout)
        return (dx,)

    return bprop


@bprop_getters.register(P.SoftmaxCrossEntropyWithLogits)
def get_bprop_softmax_cross_entropy_with_logits(self):
    """Grad definition for `SoftmaxCrossEntropyWithLogits` operation."""
    expand = P.ExpandDims()

    def bprop(logits, labels, out, dout):
        grad = out[1]
        grad = grad * expand(dout[0], -1)
        return grad, zeros_like(labels)

    return bprop


@bprop_getters.register(P.NLLLoss)
def get_bprop_nll_loss(self):
    """Grad definition for `NLLLoss` operation."""
    nll_loss_grad = G.NLLLossGrad(reduction=self.reduction)

    def bprop(x, target, weight, out, dout):
        total_weight = out[1]
        dout_x = dout[0]
        dx = nll_loss_grad(x, dout_x, target, weight, total_weight)
        return dx, zeros_like(target), zeros_like(weight)

    return bprop


@bprop_getters.register(P.SparseSoftmaxCrossEntropyWithLogits)
def get_bprop_sparse_softmax_cross_entropy_with_logits(self):
    """Grad definition for `SparseSoftmaxCrossEntropyWithLogits` operation."""
    is_grad = self.is_grad
    grad_op = P.SparseSoftmaxCrossEntropyWithLogits(is_grad=True)

    def bprop(logits, labels, out, dout):
        grad = out[0]
        if not is_grad:
            # if `construct` uses the loss output, recompute the gradient with `is_grad=True`
            grad = grad_op(logits, labels)
            grad = F.depend(grad, out)
            grad = grad * dout
        return grad, zeros_like(labels)

    return bprop


@bprop_getters.register(P.ResizeBilinear)
def get_bprop_resize_bilinear(self):
    """Grad definition for `ResizeBilinear` operation."""
    resize_grad = G.ResizeBilinearGrad(self.align_corners)

    def bprop(x, out, dout):
        dx = resize_grad(dout, x)
        return (dx,)

    return bprop


@bprop_getters.register(P.OneHot)
def get_bprop_onehot(self):
    """Grad definition for `OneHot` operation."""

    def bprop(indices, depth, on_value, off_value, out, dout):
        return zeros_like(indices), zeros_like(depth), zeros_like(on_value), zeros_like(off_value)

    return bprop


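# The TopK bprop scatters the incoming value gradients back to the positions selected in
# the forward pass: the 2-D indices are offset into a flattened index space
# (row_offset * last_dim + column), ScatterNd writes dout[0] at those positions, and the
# result is reshaped back to the input shape. The gradient w.r.t. k is zero.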
@bprop_getters.register(P.TopK)
def get_bprop_top_kv2(self):
    """Grad definition for `TopK` operation."""
    scatter = P.ScatterNd()
    expand_dims = P.ExpandDims()
    shape_op = P.Shape()
    reshape_op = P.Reshape()
    dtype = P.DType()

    def bprop(input_x, k, out, dout):
        in_shape = shape_op(input_x)
        in_lastdim = in_shape[-1]
        indices = out[1]
        ind_shape = shape_op(indices)
        ind_lastdim = ind_shape[-1]
        ind_2d = reshape_op(indices, (-1, ind_lastdim))
        outerdim = shape_op(ind_2d)[0]
        # [0, in_lastdim, 2 * in_lastdim, ..., (outerdim - 1) * in_lastdim]
        indices_dtype = dtype(indices)
        range_flatten_index = range_op(0, outerdim * in_lastdim, in_lastdim, indices_dtype)
        # expand_dims to (outerdim, 1), then broadcast against ind_2d
        ind = reshape_op(ind_2d + expand_dims(range_flatten_index, -1), (-1,))
        in_shape_1d = get_1d_shape(in_shape)
        out_grad = reshape_op(
            scatter(
                expand_dims(ind, -1),
                reshape_op(dout[0], (-1,)),
                in_shape_1d),
            in_shape)
        return out_grad, zeros_like(k)

    return bprop


@bprop_getters.register(P.SmoothL1Loss)
def get_bprop_smooth_l1_loss(self):
    """Grad definition for `SmoothL1Loss` operation."""
    grad = G.SmoothL1LossGrad(self.beta)

    def bprop(prediction, target, out, dout):
        dx = grad(prediction, target, dout)
        dy = grad(target, prediction, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.L2Loss)
def get_bprop_l2_loss(self):
    """Grad definition for `L2Loss` operation."""

    def bprop(x, out, dout):
        # L2Loss(x) = sum(x ** 2) / 2, so dL/dx = x.
        dx = x * dout
        return (dx,)

    return bprop


@bprop_getters.register(P.RNNTLoss)
def get_bprop_rnnt_loss(self):
    """Grad definition for `RNNTLoss` operation."""

    def bprop(acts, labels, act_lens, label_lens, out, dout):
        grad = out[1]
        return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)

    return bprop


@bprop_getters.register(P.PReLU)
def get_bprop_prelu(self):
    """Grad definition for `PReLU` operation."""
    grad = G.PReLUGrad()

    def bprop(x, w, out, dout):
        dx, dw = grad(dout, x, w)
        return dx, dw

    return bprop


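# LSTM keeps two backward paths: on CPU a single fused LSTMGrad kernel produces all the
# gradients at once, while other backends split the work into LSTMGradData (dx, dhx, dcx)
# and LSTMGradWeight (dw), with F.depend ordering the weight pass after the data pass.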
@bprop_getters.register(P.LSTM)
def get_bprop_lstm(self):
    """Grad definition for `LSTM` operation."""
    lstm_grad_data = G.LSTMGradData(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    lstm_grad_weight = G.LSTMGradWeight(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )
    lstm_grad = G.LSTMGrad(
        input_size=self.input_size,
        hidden_size=self.hidden_size,
        num_layers=self.num_layers,
        has_bias=self.has_bias,
        bidirectional=self.bidirectional,
        dropout=self.dropout
    )

    def bprop(x, hx, cx, w, out, dout):
        y, _, _, reserve, state = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx = lstm_grad_data(y, dy, dhy, dcy, w, hx, cx, reserve, state)
        dw = lstm_grad_weight(F.depend(x, dx), hx, y, reserve, state)
        return dx, dhx, dcx, dw

    #
    def bprop_cpu(x, hx, cx, w, out, dout):
        y, hy, cy, reserve, _ = out
        dy, dhy, dcy, _, _ = dout
        dx, dhx, dcx, dw = lstm_grad(x, hx, cx, w, y, hy, cy, dy, dhy, dcy, reserve)
        return dx, dhx, dcx, dw

    if context.get_context('device_target') == "CPU":
        return bprop_cpu

    return bprop


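# DynamicRNN returns per-step outputs (y, h, c, i, j, f, o, tanhct); only the last time
# step of the incoming dh/dc is consumed by DynamicRNNGrad, so they are sliced with [-1]
# below, and the resulting dh_prev/dc_prev are expanded back to a leading step dimension.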
@bprop_getters.register(P.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Grad definition for `DynamicRNN` operation."""
    dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type,
                                        direction=self.direction,
                                        cell_depth=self.cell_depth,
                                        use_peephole=self.use_peephole,
                                        keep_prob=self.keep_prob,
                                        cell_clip=self.cell_clip,
                                        num_proj=self.num_proj,
                                        time_major=self.time_major,
                                        forget_bias=self.forget_bias)
    expand_dims = P.ExpandDims()

    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
        dy, dh, dc, _, _, _, _, _, = dout
        dh = dh[-1]
        dc = dc[-1]
        y, h, c, i, j, f, o, tanhct = out
        dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
                                                        c, dy, dh, dc, i, j, f, o, tanhct)
        dh_prev = expand_dims(dh_prev, 0)
        dc_prev = expand_dims(dc_prev, 0)
        return dx, dw, db, (0), dh_prev, dc_prev

    return bprop


@bprop_getters.register(P.DynamicGRUV2)
def get_bprop_dynamic_gru_v2(self):
    """Grad definition for `DynamicGRUV2` operation."""
    dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip,
                                             self.num_proj, self.time_major, self.gate_order,
                                             self.reset_after)

    def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout):
        y, out_h, update, reset, new, hidden_new = out
        dy, dout_h, _, _, _, _ = dout
        dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h,
                                                                                    out_h, dy, dout_h[-1], update,
                                                                                    reset, new, hidden_new, None, None)
        return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev

    return bprop


@bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
def get_bprop_sigmoid_crossentropy_with_logits(self):
    """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
    op = G.SigmoidCrossEntropyWithLogitsGrad()

    def bprop(x, y, out, dout):
        dx = op(x, y, dout)
        return (dx, zeros_like(y))

    return bprop


@bprop_getters.register(P.Pad)
def get_bprop_pad(self):
    """Grad definition for `Pad` operation."""
    shape_op = P.Shape()
    paddings = self.paddings

    def bprop(x, out, dout):
        begin = ()
        for item in paddings:
            begin += (item[0],)
        shp = shape_op(x)
        dx = P.Slice()(dout, begin, shp)
        return (dx,)

    return bprop


@bprop_getters.register(P.MirrorPad)
def get_bprop_mirror_pad(self):
    """Grad definition for `MirrorPad` operation."""
    mirror_pad_grad = G.MirrorPadGrad(self.mode)

    def bprop(x, paddings, out, dout):
        dx = mirror_pad_grad(dout, paddings)
        return (dx, zeros_like(paddings))

    return bprop


@bprop_getters.register(P.ROIAlign)
def get_bprop_roi_align(self):
    """Grad definition for `ROIAlign` operation."""
    shape_op = P.Shape()
    pooled_height = self.pooled_height
    pooled_width = self.pooled_width
    spatial_scale = self.spatial_scale
    sample_num = self.sample_num

    def bprop(inputs, rois, out, dout):
        inputs_shape = shape_op(inputs)
        dx = G.ROIAlignGrad(inputs_shape,
                            pooled_height,
                            pooled_width,
                            spatial_scale,
                            sample_num,
                            )(dout, rois)
        return dx, zeros_like(rois)

    return bprop


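# Conv2DBackpropInput is itself differentiable: its gradient w.r.t. the incoming `dout`
# is a forward Conv2D with the same filter, and its gradient w.r.t. the filter is
# Conv2DBackpropFilter with the roles of the activation and the gradient swapped.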
@bprop_getters.register(P.Conv2DTranspose)
@bprop_getters.register(P.Conv2DBackpropInput)
def get_bprop_conv2d_backprop_input(self):
    """Grad definition for `Conv2DBackpropInput` operation."""
    pad_list = self.get_attr_dict()['pad_list']
    out_channel = self.get_attr_dict()['out_channel']
    filter_grad = G.Conv2DBackpropFilter(
        out_channel, self.kernel_size, self.pad_mode, self.pad, pad_list, mode=self.mode,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    input_grad = P.Conv2D(
        out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
        dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
    )
    get_shape = P.Shape()
    get_dyn_shape = P.DynamicShape()

    def bprop(x, w, f_sizes, out, dout):
        w_shape = get_shape(w)
        if -1 in w_shape:
            w_shape = get_dyn_shape(w)
        dx = input_grad(dout, w)
        dw = filter_grad(x, dout, w_shape)
        return dx, dw, zeros_like(f_sizes)

    return bprop


@bprop_getters.register(P.BinaryCrossEntropy)
def get_bprop_binary_cross_entropy(self):
    """Grad definition for `BinaryCrossEntropy` operation."""
    grad = G.BinaryCrossEntropyGrad(self.reduction)

    def bprop(x, y, weight, out, dout):
        dx = grad(x, y, dout, weight)
        return dx, zeros_like(y), zeros_like(weight)

    return bprop


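# For BCEWithLogitsLoss the gradient w.r.t. the logits reduces to
# (sigmoid(predict) - target) * dout in the unweighted case; the pos_weight branch applies
# the same derivative with the positive term rescaled, and 'mean' reduction divides by the
# element count.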
@bprop_getters.register(P.BCEWithLogitsLoss)
def get_bprop_ce_with_logits_loss(self):
    """Grad definition for `BCEWithLogitsLoss` operation."""
    reduction = self.reduction
    mul = P.Mul()
    sigmoid = P.Sigmoid()
    add = P.Add()
    sub = P.Sub()
    size = P.Size()
    neg = P.Neg()
    log = P.Log()

    def bprop(predict, target, weight, pos_weight, out, dout):
        sigmoid_input = sigmoid(predict)
        if pos_weight is not None:
            t = mul(target, pos_weight)
            dx = mul(sub(mul(sub(add(t, 1), target), sigmoid_input), t), dout)
            grad_target = mul(sub(log(sub(1, sigmoid_input)), mul(pos_weight, log(sigmoid_input))), dout)
        else:
            dx = mul((sigmoid_input - target), dout)
            grad_target = mul(predict, neg(dout))
        if weight is not None:
            dx = mul(dx, weight)
            grad_target = mul(grad_target, weight)
        if reduction == 'mean':
            dx = dx / size(dx)
            grad_target = grad_target / size(target)
        return dx, grad_target, zeros_like(weight), zeros_like(pos_weight)

    return bprop


@bprop_getters.register(P.KLDivLoss)
def get_bprop_kl_div_loss(self):
    """Grad definition for `KLDivLoss` operation."""
    grad = G.KLDivLossGrad(self.reduction)

    def bprop(x, y, out, dout):
        dx, dy = grad(x, y, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.Dropout)
def get_bprop_dropout(self):
    """Grad definition for `Dropout` operation."""
    grad = G.DropoutGrad(self.keep_prob)

    def bprop(x, out, dout):
        _, mask = out
        dy, _ = dout
        dx = grad(dy, mask)
        return (dx,)

    return bprop


@bprop_getters.register(P.Dropout2D)
@bprop_getters.register(P.Dropout3D)
def get_bprop_dropout3d(self):
    """Grad definition for `Dropout2D` and `Dropout3D` operation."""
    dtype = P.DType()
    cast = P.Cast()
    mul = P.Mul()
    keep_prob = self.keep_prob

    def bprop(x, out, dout):
        _, mask = dout
        y = cast(mask, mstype.float32)
        if keep_prob != 0:
            y = y * (1 / keep_prob)
        y = mul(x, y)
        y = cast(y, dtype(x))
        return (y,)

    return bprop


@bprop_getters.register(P.CTCLoss)
def get_bprop_ctc_loss(self):
    """Grad definition for `CTCLoss` operation."""
    expand = P.ExpandDims()

    def bprop(inputs, labels_indices, labels_values, sequence_length, out, dout):
        grad_loss = out[1]
        grad = grad_loss * expand(dout[0], -1)
        return grad, zeros_like(labels_indices), zeros_like(labels_values), zeros_like(sequence_length)

    return bprop


@bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self):
    """Grad definition for `BasicLSTMCell` operation."""
    basic_lstm_cell_cstate_grad = G.BasicLSTMCellCStateGrad(
        forget_bias=self.forget_bias,
        activation=self.activation
    )
    basic_lstm_cell_weight_grad = G.BasicLSTMCellWeightGrad()
    basic_lstm_cell_input_grad = G.BasicLSTMCellInputGrad(keep_prob=self.keep_prob)

    def bprop(x, h, c, w, b, out, dout):
        _, _, it, jt, ft, ot, tanhct = out
        dct, dht, _, _, _, _, _ = dout
        dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
        dxt, dht = basic_lstm_cell_input_grad(dgate, w)
        dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
        return dxt, dht, dct_1, dw, db

    return bprop


@bprop_getters.register(P.LRN)
def get_bprop_lrn(self):
    """Grad definition for `LRN` operation."""
    grad = G.LRNGrad(self.depth_radius, self.bias, self.alpha, self.beta)

    def bprop(x, out, dout):
        dx = grad(dout, x, out)
        return (dx,)

    return bprop