
Add the reshape part of the EmbeddingLookup backward operator

tags/v0.5.0-beta
Xiaoda Zhang 6 years ago
parent commit 1cfb52bc0e
4 changed files with 39 additions and 7 deletions
  1. mindspore/ccsrc/parallel/ops_info/ops_utils.h (+1 -1)
  2. mindspore/ops/operations/_grad_ops.py (+32 -0)
  3. mindspore/ops/operations/array_ops.py (+1 -1)
  4. tests/ut/python/parallel/test_gather_v2.py (+5 -5)

mindspore/ccsrc/parallel/ops_info/ops_utils.h (+1 -1)

@@ -76,7 +76,7 @@ constexpr char DEPEND[] = "depend";
 constexpr char BATCH_PARALLEL[] = "BatchParallel";
 
 constexpr char ACTIVATION_TYPE[] = "activation_type";
-constexpr char TARGET[] = "target";
+constexpr char TARGET[] = "primitive_target";
 constexpr char CPU[] = "CPU";
 constexpr char TRANSPOSE_A[] = "transpose_a";
 constexpr char TRANSPOSE_B[] = "transpose_b";
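
The rename matters beyond this header: TARGET is the key under which the parallel pass looks up a primitive's placement attribute, so Python-side primitives must now register it as "primitive_target" rather than "target", which is exactly what the Python hunks below do. A minimal sketch of that contract, assuming only the standard Primitive.add_prim_attr/attrs API (it mirrors the test change at the end of this diff):

# Sketch: registering the placement attribute under the renamed key.
# The C++ parallel code looks the attribute up via the TARGET constant,
# i.e. under "primitive_target".
from mindspore.ops import operations as P

gatherv2 = P.GatherV2().add_prim_attr("primitive_target", "CPU")
assert gatherv2.attrs["primitive_target"] == "CPU"  # stored under the new key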


mindspore/ops/operations/_grad_ops.py (+32 -0)

@@ -21,6 +21,7 @@ from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
 from ..._checkparam import Validator as validator, Rel
 from .._utils import get_concat_offset
 from ...common import dtype as mstype
+from .. import functional as F
 
 
 class AbsGrad(PrimitiveWithInfer):
@@ -1121,6 +1122,37 @@ class MirrorPadGrad(PrimitiveWithInfer):
                 'value': None}
 
 
+class EmbeddingLookupCommGrad(PrimitiveWithInfer):
+    """
+    Performs the gradient for the communication part of the EmbeddingLookup operator.
+
+    This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking,
+    this primitive is implemented by StridedSlice --> HostAllGather --> Concat. This primitive runs on host.
+    """
+    @prim_attr_register
+    def __init__(self):
+        self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output'])
+        self.add_prim_attr('primitive_target', 'CPU')
+
+    def __infer__(self, dy, split_num):
+        """
+        This primitive is implemented by three steps:
+            1) Split 'dy' along dimension 0 into 'split_num' parts.
+            2) For each part, perform HostAllGather((0, 1, 2, 3, 4, 5, 6, 7)) on the host.
+            3) After HostAllGather, there are still 'split_num' parts in each process; perform Concat on them
+               along dimension 0.
+
+        The output shape of this primitive satisfies: shape(output)[0] == shape(dy)[0] * 8.
+        """
+        dy_shape = tuple(dy['shape'])
+        split_num_value = split_num['value']
+        validator.check_value_type("split_num_value", split_num_value, [int], self.name)
+        dy_shape_all = F.tuple_setitem(dy_shape, 0, dy_shape[0] * 8)
+        return {'shape': dy_shape_all,
+                'dtype': dy['dtype'],
+                'value': None}
+
+
 class RefToEmbed(Primitive):
     r"""
     Make a key from Ref.
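
The shape rule in the new primitive's docstring can be checked with plain arithmetic: 'dy' is cut into 'split_num' slices along dimension 0, each slice is gathered from the 8 host processes, and the gathered slices are concatenated back, so dimension 0 grows by exactly 8 regardless of 'split_num'. A pure-Python sketch of that computation (GROUP_SIZE and the helper name are illustrative, not part of the operator):

# Sketch of the output-shape rule in EmbeddingLookupCommGrad.__infer__.
GROUP_SIZE = 8  # HostAllGather group (0, 1, 2, 3, 4, 5, 6, 7)

def infer_comm_grad_shape(dy_shape, split_num):
    assert dy_shape[0] % split_num == 0  # dim 0 splits into split_num parts
    slice_dim0 = dy_shape[0] // split_num
    # Each slice is gathered from GROUP_SIZE processes; the split_num gathered
    # slices are then concatenated back along dimension 0.
    out_dim0 = split_num * (slice_dim0 * GROUP_SIZE)
    return (out_dim0,) + tuple(dy_shape[1:])

# shape(output)[0] == shape(dy)[0] * 8, whatever split_num is:
assert infer_comm_grad_shape((16, 64), split_num=2) == (128, 64)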


mindspore/ops/operations/array_ops.py (+1 -1)

@@ -614,7 +614,7 @@ class EmbeddingLookup(PrimitiveWithInfer):
         self.__setattr_flag__ = True
         self.init_prim_io_names(inputs=['params', 'indices', 'axis', 'offset', 'reduce_scatter_flag', 'split_num'],
                                 outputs=['output'])
-        self.add_prim_attr('target', 'CPU')
+        self.add_prim_attr('primitive_target', 'CPU')
 
     def __infer__(self, params, indices, axis, offset, reduce_scatter_flag=False, split_num=2):
         validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)


tests/ut/python/parallel/test_gather_v2.py (+5 -5)

@@ -45,11 +45,11 @@ class GradWrap(nn.Cell):
 
 
 class Net(nn.Cell):
-    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None):
+    def __init__(self, axis=0, strategy1=None, strategy2=None, shape=None, target=""):
         super().__init__()
         if shape is None:
             shape = [64, 64]
-        self.gatherv2 = P.GatherV2().set_strategy(strategy1)
+        self.gatherv2 = P.GatherV2().set_strategy(strategy1).add_prim_attr("primitive_target", target)
         self.mul = P.Mul().set_strategy(strategy2)
         self.index = Tensor(np.ones(shape), dtype=ms.int32)
         self.axis = axis
@@ -188,7 +188,7 @@ def test_gatherv2_cpu0():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()
 
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -200,7 +200,7 @@ def test_gatherv2_cpu1():
     context.set_auto_parallel_context(device_num=16, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((16, 1), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()
 
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)
@@ -212,7 +212,7 @@ def test_gatherv2_cpu2():
     context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel")
     strategy1 = ((1, 8), (1, 1))
     strategy2 = ((4, 2, 1), (4, 2, 1))
-    net = NetWithLoss(Net(0, strategy1, strategy2))
+    net = NetWithLoss(Net(0, strategy1, strategy2, None, "CPU"))
     net.set_auto_parallel()
 
     x = Tensor(np.ones([64, 64]), dtype=ms.float32)

