You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad_comm_ops.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Generate bprop for comm ops"""
  16. import mindspore.common.dtype as mstype
  17. from mindspore.ops import functional as F
  18. from .. import operations as P
  19. from ..composite.multitype_ops.zeros_like_impl import zeros_like
  20. from ..operations.comm_ops import (AllGather, AllReduce, _AlltoAll, Broadcast,
  21. _GetTensorSlice, _MirrorOperator, ReduceOp,
  22. ReduceScatter, _VirtualDiv)
  23. from .grad_base import bprop_getters
  24. @bprop_getters.register(AllReduce)
  25. def get_bprop_all_reduce(self):
  26. """Generate bprop for AllReduce."""
  27. all_reduce_grad = AllReduce(ReduceOp.SUM, self.group)
  28. if self.instance_name:
  29. instance_name = "grad" + self.instance_name
  30. all_reduce_grad.set_prim_instance_name(instance_name)
  31. equal = P.Equal()
  32. cast = P.Cast()
  33. mul = P.Mul()
  34. dtype = P.DType()
  35. if self.op == ReduceOp.PROD:
  36. raise RuntimeError("The bprop of ReduceOp.PROD is not supported yet.")
  37. if self.op == ReduceOp.SUM:
  38. def bprop(x, out, dout):
  39. dx = all_reduce_grad(dout)
  40. return (dx,)
  41. else:
  42. def bprop(x, out, dout):
  43. dx = all_reduce_grad(dout)
  44. z = equal(x, out)
  45. z = cast(z, dtype(dx))
  46. dx = mul(dx, z)
  47. return (dx,)
  48. return bprop
  49. @bprop_getters.register(Broadcast)
  50. def get_bprop_broad_cast(self):
  51. """Generate bprop for Broadcast."""
  52. def bprop(x, out, dout):
  53. return (dout,)
  54. return bprop
  55. @bprop_getters.register(AllGather)
  56. def get_bprop_all_gather(self):
  57. """Generate bprop for AllGather"""
  58. all_gather_grad = ReduceScatter(ReduceOp.SUM, self.group)
  59. if self.instance_name:
  60. instance_name = "grad" + self.instance_name
  61. all_gather_grad.set_prim_instance_name(instance_name)
  62. def bprop(x, out, dout):
  63. dx = all_gather_grad(dout)
  64. return (dx,)
  65. return bprop
  66. @bprop_getters.register(ReduceScatter)
  67. def get_bprop_reduce_scatter(self):
  68. """Generate bprop for ReduceScatter"""
  69. reduce_scatter_grad = AllGather(self.group)
  70. if self.instance_name:
  71. instance_name = "grad" + self.instance_name
  72. reduce_scatter_grad.set_prim_instance_name(instance_name)
  73. if self.op != ReduceOp.SUM:
  74. raise RuntimeError("The reducescatter bprop only support ReduceOp.SUM until now.")
  75. def bprop(x, out, dout):
  76. dx = reduce_scatter_grad(dout)
  77. return (dx,)
  78. return bprop
  79. @bprop_getters.register(_AlltoAll)
  80. def get_bprop_all_to_all(self):
  81. """Generate bprop for AlltoAll."""
  82. all_to_all_grad = _AlltoAll(self.split_count, self.concat_dim, self.split_dim, self.group)
  83. if self.instance_name:
  84. instance_name = "grad" + self.instance_name
  85. all_to_all_grad.set_prim_instance_name(instance_name)
  86. def bprop(x, out, dout):
  87. dx = all_to_all_grad(dout)
  88. return (dx,)
  89. return bprop
  90. @bprop_getters.register(_MirrorOperator)
  91. def get_bprop_mirror_operator(self):
  92. """Backpropagator for _MirrorOperator, do allreduce for the devices in group(only for one group)."""
  93. group = self.group
  94. dev_num = self.dev_num
  95. mean_flag = self.mean_flag
  96. all_reduce = AllReduce(group=group)
  97. mul = P.Mul()
  98. cast = P.Cast()
  99. fusion = 1
  100. if hasattr(self, 'fusion'):
  101. fusion = self.fusion
  102. all_reduce.add_prim_attr("fusion", fusion)
  103. if hasattr(self, 'parameter'):
  104. parameter = self.parameter
  105. all_reduce.add_prim_attr("parameter", parameter)
  106. if self.instance_name:
  107. instance_name = "grad_mirror" + self.instance_name
  108. all_reduce.set_prim_instance_name(instance_name)
  109. def bprop(x, out, dout):
  110. if mean_flag:
  111. dx = all_reduce(dout)
  112. float_one = F.scalar_cast(1.0, F.dtype(dx))
  113. num = F.scalar_cast(dev_num, F.dtype(dx))
  114. dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
  115. else:
  116. dx = all_reduce(dout)
  117. return (dx,)
  118. return bprop
  119. @bprop_getters.register(_VirtualDiv)
  120. def get_bprop_virtual_div_operator(self):
  121. """Backpropagator for _VirtualDiv, do Div for the divisor."""
  122. divisor = self.divisor
  123. op = P.RealDiv()
  124. cast = P.Cast()
  125. dtype = P.DType()
  126. def bprop(x, out, dout):
  127. if F.issubclass_(F.dtype(dout), mstype.bool_):
  128. return (dout,)
  129. dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
  130. return (dx,)
  131. return bprop
  132. @bprop_getters.register(_GetTensorSlice)
  133. def get_bprop_get_tensor_slice_operator(self):
  134. """Backpropagator for _GetTensorSlice"""
  135. def bprop(x, dev_mat, tensor_map, out, dout):
  136. return (zeros_like(x),)
  137. return bprop