
grad_comm_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate bprop for comm ops"""

import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from mindspore.communication import get_rank, get_group_size
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                   _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive


@bprop_getters.register(AllReduce)
def get_bprop_all_reduce(self):
    """Generate bprop for AllReduce, do allreduce or allgather, allgather for sparse feature."""
    all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
    all_gather = AllGather(group=self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_reduce_grad.set_prim_instance_name(instance_name)
    equal = P.Equal()
    cast = P.Cast()
    mul = P.Mul()
    div = P.RealDiv()
    dtype = P.DType()
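
    # Three cases, keyed by the forward reduce op:
    #   PROD:    d(prod)/dx = prod / x, so dx = allreduce_sum(dout * out) / x.
    #   SUM:     the gradient of a sum passes through unchanged, dx = allreduce_sum(dout);
    #            sparse (RowTensor) gradients are allgathered instead.
    #   MAX/MIN: only elements that won the reduction (x == out) receive gradient,
    #            selected with a 0/1 mask.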
    if self.op == ReduceOp.PROD:
        def bprop(x, out, dout):
            dy1 = mul(dout, out)
            dy2 = all_reduce_grad(dy1)
            dx = div(dy2, x)
            return (dx,)
    elif self.op == ReduceOp.SUM:
        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                dx = RowTensor(indices, grad, dout.dense_shape)
            return (dx,)
    else:
        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
                z = equal(x, out)
                z = cast(z, dtype(dx))
                dx = mul(dx, z)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                z = equal(x, out)
                z = cast(z, dtype(grad))
                grad = mul(grad, z)
                dx = RowTensor(indices, grad, dout.dense_shape)
            return (dx,)
    return bprop
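
# Illustrative check of the PROD rule above (a plain-NumPy sketch, not part of
# this module): with one scalar per simulated device, allreduce_sum(dout * out) / x
# reproduces d(out)/dx = out / x scaled by the summed upstream gradients.
#
#     import numpy as np
#     x = np.array([2.0, 3.0, 4.0])       # x[i] held by device i
#     out = np.prod(x)                    # AllReduce(PROD) result, same everywhere
#     dout = np.array([1.0, 0.5, 2.0])    # upstream gradient on each device
#     dx = np.sum(dout * out) / x         # allreduce_sum(dout * out) / x
#     assert np.allclose(dx, dout.sum() * out / x)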


@bprop_getters.register(Send)
def get_bprop_send(self):
    """Generate bprop for Send."""
    shape = self.get_attr_dict()["shape"]
    dtype = self.get_attr_dict()["dtype"]
    send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
    send_grad.add_prim_attr("backward", True)
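
    # The gradient of a Send is a Receive from the same peer: the consumer of the
    # forward value sends its gradient back through this paired, backward-tagged op.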
    def bprop(x, out, dout):
        dx = send_grad()
        return (dx,)
    return bprop


@bprop_getters.register(Receive)
def get_bprop_receive(self):
    """Generate bprop for Receive."""
    receive_grad = Send(self.tag, self.rank, self.group)
    receive_grad.add_prim_attr("backward", True)
    depend = P.Depend()
    cast = P.Cast()
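
    # Dually, the gradient of a Receive is a Send of dout back to the producer.
    # The local input gets a zero gradient; Depend ties it to send_out so the
    # backward Send is not pruned from the graph.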
    def bprop(x, out, dout):
        send_out = receive_grad(dout)
        dx = depend(cast(zeros_like(x), F.dtype(x)), send_out)
        return (dx,)
    return bprop


@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
    """Generate bprop for Broadcast."""
    def bprop(x, out, dout):
        return (dout,)
    return bprop


@bprop_getters.register(AllGather)
def get_bprop_all_gather(self):
    """Generate bprop for AllGather"""
    fusion = self.get_attr_dict()["fusion"]
    if fusion == 0:
        reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group)
        if self.instance_name:
            instance_name = "grad_" + self.instance_name
            reduce_scatter.set_prim_instance_name(instance_name)
    else:
        all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion)
        if self.instance_name:
            instance_name = "grad_" + self.instance_name
            all_reduce.set_prim_instance_name(instance_name)
        rank = get_rank(self.group)
        dev_num = get_group_size(self.group)
        split = P.Split(output_num=dev_num)
    mean_flag = self.get_attr_dict()["mean_flag"]
    scale = 1 / self.rank_size
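
    # ReduceScatter(SUM) is the adjoint of AllGather: each rank gets back the sum
    # of the gradient slices that correspond to its own contribution. With fusion
    # enabled, the same result is computed as a fused AllReduce followed by taking
    # this rank's split; mean_flag additionally rescales by 1 / rank_size.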
    def bprop(x, out, dout):
        if fusion == 0:
            dx = reduce_scatter(dout)
        else:
            grad = all_reduce(dout)
            dx = split(grad)[rank]
        if mean_flag:
            dx = F.tensor_mul(dx, scale)
        return (dx,)
    return bprop
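
# Illustrative adjoint check for AllGather / ReduceScatter (plain-NumPy sketch,
# not part of this module; `n`, `r`, and the slice length 2 are arbitrary):
#
#     import numpy as np
#     n, r = 4, 1                                        # group size, this rank
#     douts = [np.random.rand(2 * n) for _ in range(n)]  # per-rank upstream grads
#     dx = np.sum(douts, axis=0)[2 * r:2 * (r + 1)]      # ReduceScatter(SUM) slice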


@bprop_getters.register(_HostAllGather)
def get_bprop_host_all_gather(self):
    """Generate bprop for _HostAllGather"""
    host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        host_all_gather_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = host_all_gather_grad(dout)
        return (dx,)
    return bprop


@bprop_getters.register(ReduceScatter)
def get_bprop_reduce_scatter(self):
    """Generate bprop for ReduceScatter"""
    reduce_scatter_grad = AllGather(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        reduce_scatter_grad.set_prim_instance_name(instance_name)
    if self.op != ReduceOp.SUM:
        raise RuntimeError("The ReduceScatter bprop currently only supports ReduceOp.SUM.")
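
    # AllGather is in turn the adjoint of ReduceScatter(SUM): every rank's output
    # slice depends on every rank's input, so the gradient is replicated back whole.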
    def bprop(x, out, dout):
        dx = reduce_scatter_grad(dout)
        return (dx,)
    return bprop


@bprop_getters.register(AllSwap)
def get_bprop_allswap(self):
    """Generate bprop for AllSwap."""
    all_swap_grad = AllSwap(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_swap_grad.set_prim_instance_name(instance_name)
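
    # Swapping the send and receive sizes reverses the exchange, routing each
    # gradient chunk back to the rank that produced the corresponding forward chunk.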
    def bprop(x, send_size, recv_size, out, dout):
        dx = all_swap_grad(dout, recv_size, send_size)
        return (dx, zeros_like(send_size), zeros_like(recv_size))
    return bprop


@bprop_getters.register(_HostReduceScatter)
def get_bprop_host_reduce_scatter(self):
    """Generate bprop for _HostReduceScatter"""
    host_reduce_scatter_grad = _HostAllGather(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        host_reduce_scatter_grad.set_prim_instance_name(instance_name)
    if self.op != ReduceOp.SUM:
        raise RuntimeError("The _HostReduceScatter bprop currently only supports ReduceOp.SUM.")

    def bprop(x, out, dout):
        dx = host_reduce_scatter_grad(dout)
        return (dx,)
    return bprop


@bprop_getters.register(_AlltoAll)
def get_bprop_all_to_all(self):
    """Generate bprop for AlltoAll."""
    all_to_all_grad = _AlltoAll(self.split_count, self.concat_dim, self.split_dim, self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_to_all_grad.set_prim_instance_name(instance_name)
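
    # AlltoAll is its own inverse up to exchanging the split and concat dimensions,
    # so the backward pass reruns it with concat_dim and split_dim swapped.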
    def bprop(x, out, dout):
        dx = all_to_all_grad(dout)
        return (dx,)
    return bprop


@bprop_getters.register(_MirrorOperator)
def get_bprop_mirror_operator(self):
    """
    Backpropagator for _MirrorOperator: do AllReduce for dense gradients or AllGather
    for sparse features, over the devices in the group (a single group only).
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag
    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()
    fusion = self.get_attr_dict()["fusion"]
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)
    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)
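
    # A mirrored parameter is replicated on every device in the group, so its total
    # gradient is the AllReduce(SUM) of the local gradients; with mean_flag set, the
    # sum is rescaled by 1 / dev_num. Sparse (RowTensor) gradients are AllGathered
    # by indices and values instead of reduced.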
    def bprop(x, out, dout):
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one / num), F.dtype(dx)))
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one / num), F.dtype(grad)))
                dx = RowTensor(indices, grad, dout.dense_shape)
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                dx = RowTensor(indices, grad, dout.dense_shape)
        return (dx,)
    return bprop


@bprop_getters.register(_MirrorMiniStepOperator)
def get_bprop_mirror_mini_step_operator(self):
    """
    Backpropagator for _MirrorMiniStepOperator: do AllReduce for dense gradients or
    AllGather for sparse features, over the devices in the group.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag
    grad_accumulation_step = self.grad_accumulation_step
    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()
    equal = P.Equal()
    reshape = P.Reshape()
    fusion = 1
    if hasattr(self, 'fusion'):
        fusion = self.fusion
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)
    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)
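
    # Under gradient accumulation, the mirror reduction only fires on the step where
    # the counter y reaches grad_accumulation_step. z carries the locally accumulated
    # gradient, so allreduce(z + dout) - z yields the reduced total while leaving the
    # local accumulator itself untouched.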
    def bprop(x, y, z, out, dout):
        do_mirror = equal(y, grad_accumulation_step)
        do_mirror = reshape(do_mirror, (()))
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    dx = real_grad - z
                else:
                    dx = dout
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one / num), F.dtype(dx)))
            else:
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one / num), F.dtype(grad)))
                dx = RowTensor(indices, grad, dout.dense_shape)
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    dx = real_grad - z
                else:
                    dx = dout
            else:
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                dx = RowTensor(indices, grad, dout.dense_shape)
        return (dx, zeros_like(y), zeros_like(z))
    return bprop


@bprop_getters.register(_VirtualDiv)
def get_bprop_virtual_div_operator(self):
    """Backpropagator for _VirtualDiv, do Div for the divisor."""
    divisor = self.divisor
    op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()
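
    # _VirtualDiv is an identity in the forward pass; its bprop divides the incoming
    # gradient by `divisor`, passing bool/int gradients through unscaled and handling
    # tensor, tuple, and list gradients case by case.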
    def bprop(x, out, dout):
        if F.issubclass_(F.typeof(dout), mstype.tensor):
            if F.issubclass_(F.dtype(dout), mstype.bool_) or F.issubclass_(F.dtype(dout), mstype.int32) \
                    or F.issubclass_(F.dtype(dout), mstype.int16):
                return (dout,)
            dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
            return (dx,)
        if F.issubclass_(F.typeof(dout), mstype.tuple_):
            dx = ()
            input_nums = F.tuple_len(dout)
            for i in range(input_nums):
                ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
                dx = dx + (ele_grad,)
            return (dx,)
        dx = []
        input_nums = F.list_len(dout)
        for i in range(input_nums):
            ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
            dx.append(ele_grad)
        return (dx,)
    return bprop


@bprop_getters.register(_GetTensorSlice)
def get_bprop_get_tensor_slice_operator(self):
    """Backpropagator for _GetTensorSlice"""
    def bprop(x, dev_mat, tensor_map, out, dout):
        return (zeros_like(x),)
    return bprop