grad_comm_ops.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate bprop for comm ops"""
import mindspore.common.dtype as mstype
from mindspore.ops import functional as F
from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                   _GetTensorSlice, _MirrorOperator, ReduceOp,
                                   ReduceScatter, _HostReduceScatter, _VirtualDiv)
from .grad_base import bprop_getters
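

# Each getter below is registered with `bprop_getters` so MindSpore's autodiff
# can look up the backward rule for the matching communication primitive.
# A rough sketch of the SUM case that follows: after AllReduce every device
# holds a replica of the output, so each device's local dout is a partial
# gradient, and another AllReduce(SUM) over the douts yields the full input
# gradient; sparse (RowTensor) gradients are AllGathered instead of reduced.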
@bprop_getters.register(AllReduce)
def get_bprop_all_reduce(self):
    """Generate bprop for AllReduce: allreduce for dense gradients, allgather for sparse features."""
    all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
    all_gather = AllGather(group=self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_reduce_grad.set_prim_instance_name(instance_name)
    equal = P.Equal()
    cast = P.Cast()
    mul = P.Mul()
    dtype = P.DType()

    if self.op == ReduceOp.PROD:
        raise RuntimeError("The bprop of ReduceOp.PROD is not supported yet.")
    if self.op == ReduceOp.SUM:

        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                dx = RowTensor(indices, grad, dout.dense_shape)
            return (dx,)
    else:

        def bprop(x, out, dout):
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce_grad(dout)
                z = equal(x, out)
                z = cast(z, dtype(dx))
                dx = mul(dx, z)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                z = equal(x, out)
                z = cast(z, dtype(grad))
                grad = mul(grad, z)
                dx = RowTensor(indices, grad, dout.dense_shape)
            return (dx,)

    return bprop
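

# Broadcast replicates x from the root device to every device in the group;
# the backward rule here simply passes each device's dout through unchanged.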
@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
    """Generate bprop for Broadcast."""

    def bprop(x, out, dout):
        return (dout,)

    return bprop
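

# AllGather concatenates the per-device slices, so its adjoint is
# ReduceScatter(SUM): every device contributes a gradient for the whole
# gathered tensor, and each device's input slice receives the sum of the
# matching slices from all devices. The "fusion" attribute is copied over so
# the backward op can be fused the same way as the forward one.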
@bprop_getters.register(AllGather)
def get_bprop_all_gather(self):
    """Generate bprop for AllGather"""
    all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
    fusion = self.get_attr_dict()["fusion"]
    all_gather_grad.add_prim_attr("fusion", fusion)
    if self.instance_name:
        instance_name = "grad_" + self.instance_name
        all_gather_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = all_gather_grad(dout)
        return (dx,)

    return bprop
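

# Host-side (CPU) variant of the AllGather rule above: the adjoint is again a
# reduce-scatter with SUM, just executed by the _Host* primitives.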
@bprop_getters.register(_HostAllGather)
def get_bprop_host_all_gather(self):
    """Generate bprop for _HostAllGather"""
    host_all_gather_grad = _HostReduceScatter(ReduceOp.SUM, self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        host_all_gather_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = host_all_gather_grad(dout)
        return (dx,)

    return bprop
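

# ReduceScatter(SUM) is the adjoint of AllGather, so its bprop is an AllGather
# of dout: each device's scattered output slice is rebuilt into the
# full-sized gradient of the input.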
@bprop_getters.register(ReduceScatter)
def get_bprop_reduce_scatter(self):
    """Generate bprop for ReduceScatter"""
    reduce_scatter_grad = AllGather(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        reduce_scatter_grad.set_prim_instance_name(instance_name)
    if self.op != ReduceOp.SUM:
        raise RuntimeError("The ReduceScatter bprop only supports ReduceOp.SUM for now.")

    def bprop(x, out, dout):
        dx = reduce_scatter_grad(dout)
        return (dx,)

    return bprop
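

# Host-side (CPU) counterpart of the ReduceScatter rule above.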
@bprop_getters.register(_HostReduceScatter)
def get_bprop_host_reduce_scatter(self):
    """Generate bprop for _HostReduceScatter"""
    host_reduce_scatter_grad = _HostAllGather(self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        host_reduce_scatter_grad.set_prim_instance_name(instance_name)
    if self.op != ReduceOp.SUM:
        raise RuntimeError("The _HostReduceScatter bprop only supports ReduceOp.SUM for now.")

    def bprop(x, out, dout):
        dx = host_reduce_scatter_grad(dout)
        return (dx,)

    return bprop
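

# _AlltoAll permutes data across devices by splitting along split_dim and
# concatenating along concat_dim; the inverse communication, and hence the
# bprop, is an _AlltoAll with those two dimensions swapped.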
@bprop_getters.register(_AlltoAll)
def get_bprop_all_to_all(self):
    """Generate bprop for AlltoAll."""
    all_to_all_grad = _AlltoAll(self.split_count, self.concat_dim, self.split_dim, self.group)
    if self.instance_name:
        instance_name = "grad" + self.instance_name
        all_to_all_grad.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        dx = all_to_all_grad(dout)
        return (dx,)

    return bprop
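

# _MirrorOperator marks a parameter that is replicated (mirrored) across the
# devices of one group. Its backward all-reduces the local gradients, and when
# mean_flag is set the sum is additionally scaled by 1/dev_num so the result
# is the mean gradient rather than the sum.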
@bprop_getters.register(_MirrorOperator)
def get_bprop_mirror_operator(self):
    """
    Backpropagator for _MirrorOperator: allreduce for dense gradients and allgather for sparse
    features over the devices in the group (only one group is supported).
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag

    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()

    fusion = 1
    if hasattr(self, 'fusion'):
        fusion = self.fusion
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)

    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)

    def bprop(x, out, dout):
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
                dx = RowTensor(indices, grad, dout.dense_shape)
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                dx = all_reduce(dout)
            else:
                indices = all_gather(dout.indices)
                grad = all_gather(dout.values)
                dx = RowTensor(indices, grad, dout.dense_shape)
        return (dx,)

    return bprop
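

# _VirtualDiv is an identity in the forward graph; its backward scales dout by
# 1/divisor, handling a plain tensor (bool gradients pass through untouched),
# a tuple of tensors, or a list of tensors.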
@bprop_getters.register(_VirtualDiv)
def get_bprop_virtual_div_operator(self):
    """Backpropagator for _VirtualDiv: divide dout by the divisor."""
    divisor = self.divisor
    op = P.RealDiv()
    cast = P.Cast()
    dtype = P.DType()

    def bprop(x, out, dout):
        if F.issubclass_(F.typeof(dout), mstype.tensor):
            if F.issubclass_(F.dtype(dout), mstype.bool_):
                return (dout,)
            dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
            return (dx,)

        if F.issubclass_(F.typeof(dout), mstype.tuple_):
            dx = ()
            input_nums = F.tuple_len(dout)
            for i in range(input_nums):
                ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
                dx = dx + (ele_grad,)
            return (dx,)

        dx = []
        input_nums = F.list_len(dout)
        for i in range(input_nums):
            ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
            dx.append(ele_grad)
        return (dx,)

    return bprop
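

# _GetTensorSlice only extracts the slice of a parameter described by the
# device matrix and tensor map during graph construction, so no gradient flows
# back through it: the bprop returns zeros shaped like the input.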
@bprop_getters.register(_GetTensorSlice)
def get_bprop_get_tensor_slice_operator(self):
    """Backpropagator for _GetTensorSlice"""

    def bprop(x, dev_mat, tensor_map, out, dout):
        return (zeros_like(x),)

    return bprop