@@ -76,7 +76,7 @@ constexpr char DEPEND[] = "depend";
 constexpr char BATCH_PARALLEL[] = "BatchParallel";
 constexpr char ACTIVATION_TYPE[] = "activation_type";
-constexpr char TARGET[] = "target";
+constexpr char TARGET[] = "primitive_target";
 constexpr char CPU[] = "CPU";
 constexpr char TRANSPOSE_A[] = "transpose_a";
 constexpr char TRANSPOSE_B[] = "transpose_b";
@@ -21,6 +21,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator, Rel
 from .._utils import get_concat_offset
 from ...common import dtype as mstype
+from .. import functional as F
 
 
 class AbsGrad(PrimitiveWithInfer):
@@ -1121,6 +1122,37 @@ class MirrorPadGrad(PrimitiveWithInfer):
                 'value': None}
 
 
+class EmbeddingLookupCommGrad(PrimitiveWithInfer):
+    """
+    Performs the gradient for the communication part of the EmbeddingLookup operator.
+    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
+    this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on the host.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
+        self.add_prim_attr('primitive_target', 'CPU')
+
+    def __infer__(self, dy, split_num):
+        """
+        This primitive is implemented by three steps:
+            1) Split 'dy' along dimension 0 into 'split_num' parts.
+            2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
+            3) After HostAllGather, there are still 'split_num' parts in each process. Then, perform Concat on them
+               along dimension 0.
+
+        The output shape of this primitive: shape(output)[0] == shape(dy)[0] * 8.
+        """
+        dy_shape = tuple(dy['shape'])
+        split_num_value = split_num['value']
+        validator.check_value_type("split_num_value", split_num_value, [int], self.name)
+        dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
+        return {'shape': dy_shape_all,
+                'dtype': dy['dtype'],
+                'value': None}
+
+
 class RefToEmbed(Primitive):
     r"""
     Make a key from Ref.
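For intuition, the shape arithmetic that __infer__ encodes can be checked outside MindSpore. Below is a minimal NumPy sketch of the Split --> HostAllGather --> Concat pipeline the docstring describes; simulate_comm_grad is a hypothetical helper, and the all-gather across the 8 host processes is faked by replicating each part 8 times:

import numpy as np

def simulate_comm_grad(dy, split_num, ranks=8):
    # 1) Split 'dy' along dimension 0 into 'split_num' parts.
    parts = np.split(dy, split_num, axis=0)
    # 2) HostAllGather: each of the 8 ranks contributes one tensor of the
    #    same shape, so dimension 0 of every part grows by a factor of 8.
    #    (Replication stands in for the real collective here.)
    gathered = [np.concatenate([p] * ranks, axis=0) for p in parts]
    # 3) Concat the 'split_num' gathered parts along dimension 0.
    return np.concatenate(gathered, axis=0)

dy = np.ones((16, 32), dtype=np.float32)
out = simulate_comm_grad(dy, split_num=2)
assert out.shape == (16 * 8, 32)   # shape(output)[0] == shape(dy)[0] * 8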
@@ -614,7 +614,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
         self.__setattr_flag__ = True
         self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
                                 outputs=['output'])
-        self.add_prim_attr('target', 'CPU')
+        self.add_prim_attr('primitive_target', 'CPU')
 
     def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
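With the attribute key renamed to 'primitive_target' (matching the C++ TARGET constant above), pinning any single primitive to the host follows the same pattern as the constructor line in this hunk. A minimal sketch, assuming the usual operations import:

from mindspore.ops import operations as P

# Route just this GatherV2 to the host CPU; the rest of the
# network keeps running on the accelerator.
gather = P.GatherV2().add_prim_attr("primitive_target", "CPU")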
@@ -45,11 +45,11 @@ class GradWrap(nn.Cell):
 
 
 class Net(nn.Cell):
-    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
+    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
         super().__init__()
         if shape is None:
             shape = [64, 64]
-        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
+        self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
         self.mul = P.Mul().set_strategy(strategy2)
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.axis = axis
@@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()
 
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()
 
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()
 
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)