
!15790 modify scatter update op

From: @yangzhenzhang
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
pull/15790/MERGE
mindspore-ci-bot (Gitee), 4 years ago · parent commit 78fcdbc7c9
3 changed files with 55 additions and 19 deletions:
1. mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc (+33 -11)
2. mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h (+1 -1)
3. tests/ut/python/parallel/test_scatter_update.py (+21 -7)

mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc (+33 -11)

@@ -29,9 +29,14 @@

namespace mindspore {
namespace parallel {
- // The first dimension of input can not be split.
- // The indices can not be split.
- // The strategy of input and the strategy of updates must be equal.
+ // The first dimension of input or updates can not be split.
+ // The first n dimensions(n is indices' dimension size) of updates can not be split.
+ // The shape of input: [A, B, ..., M], the strategy of input: (1, b, ..., m)
+ // The shape of indices: [N, O, ..., Z], the strategy of indices: (1, 1, ..., 1)
+ // The shape of updates: [N, O, ..., Z, B, ..., M], the strategy of updates: (1, 1, ..., 1, b, ..., m)
+ // The shape of output: [A, B, ..., M], the strategy of output: (1, b, ..., m)
+ // The dev matrix: (1, b, ..., m)
Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
@@ -45,11 +50,6 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}

- if (stra[0] != stra[2]) {
- MS_LOG(ERROR) << name_ << ": The strategy[0] and strategy[2] must be equal";
- return FAILED;
- }

if (stra[0].empty()) {
MS_LOG(ERROR) << name_ << ": The strategy[0] is empty";
return FAILED;
@@ -65,6 +65,16 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}

+ if (stra[2].empty()) {
+ MS_LOG(ERROR) << name_ << ": The strategy[2] is empty";
+ return FAILED;
+ }

+ if (std::accumulate(stra[2].begin(), stra[2].begin() + stra[1].size(), 1, std::multiplies<int64_t>()) != 1) {
+ MS_LOG(ERROR) << name_ << ": The indices can not be split";
+ return FAILED;
+ }

return SUCCESS;
}
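
Note: taken together, the checks above enforce the contract spelled out in the comment block at the top of the file. A minimal Python sketch of that contract (illustrative only, not MindSpore code; the helper name is hypothetical):

def check_scatter_update_strategy(input_stra, indices_stra, updates_stra):
    # first dimension of input can not be split
    if input_stra[0] != 1:
        return False
    # indices can not be split
    if any(s != 1 for s in indices_stra):
        return False
    # the first n dims of updates (n = indices' rank) can not be split ...
    n = len(indices_stra)
    if any(s != 1 for s in updates_stra[:n]):
        return False
    # ... and the remaining dims must follow input's strategy[1:]
    return list(updates_stra[n:]) == list(input_stra[1:])

check_scatter_update_strategy([1, 2, 4], [1, 1], [1, 1, 2, 4])  # True, the UT's strategy1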

@@ -81,22 +91,28 @@ Status ScatterUpdateInfo::InferDevMatrixShape() {
}

Status ScatterUpdateInfo::InferTensorMap() {
- TensorMap input_tensor_map;
- TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE);
+ if (inputs_shape_.size() != 3) {
+ MS_LOG(ERROR) << name_ << "The size of inputs shape must be 3";
+ return FAILED;
+ }
+
+ TensorMap input_tensor_map, updates_tensor_map;
+ TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE);

// cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
int64_t size = SizeToLong(inputs_shape_[0].size());
for (int64_t i = 0; i < size; ++i) {
input_tensor_map.push_back(size - i - 1);
}

+ // updates_tensor_map = indices_tensor_map + input_tensor_map[1:]
+ updates_tensor_map = indices_tensor_map;
+ for (size_t i = 1; i < input_tensor_map.size(); ++i) {
+ updates_tensor_map.push_back(input_tensor_map[i]);
+ }
inputs_tensor_map_.push_back(input_tensor_map); // input
inputs_tensor_map_.push_back(indices_tensor_map); // indices
- inputs_tensor_map_.push_back(input_tensor_map); // updates
+ inputs_tensor_map_.push_back(updates_tensor_map); // updates

outputs_tensor_map_.push_back(input_tensor_map);
return SUCCESS;
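
Note: a concrete reading of the tensor-map rule above, using the shapes from the updated UT (a Python sketch with a hypothetical helper, assuming MAP_NONE = -1 as in the parallel module):

MAP_NONE = -1

def infer_updates_tensor_map(input_tensor_map, indices_rank):
    # updates_tensor_map = indices_tensor_map + input_tensor_map[1:]
    return [MAP_NONE] * indices_rank + input_tensor_map[1:]

# dev matrix (1, 2, 4): input [32, 64, 128] maps to [2, 1, 0];
# indices [4, 8] maps to [-1, -1]; so updates [4, 8, 64, 128]
# maps to [-1, -1, 1, 0].
assert infer_updates_tensor_map([2, 1, 0], 2) == [-1, -1, 1, 0]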
@@ -169,9 +185,15 @@ Status ScatterUpdateInfo::GenerateStrategies(int64_t stage_id) {
Strategys tmp_strategy;
Dimensions first_input_strategy = sp->GetInputDim()[0];
Dimensions indices_strategy(inputs_shape_[1].size(), 1);
+ // updates_strategy = indices_strategy + input_strategy[1:]
+ Dimensions updates_strategy = indices_strategy;
+ for (size_t i = 1; i < first_input_strategy.size(); ++i) {
+ updates_strategy.push_back(first_input_strategy[i]);
+ }

tmp_strategy.push_back(first_input_strategy); // input
tmp_strategy.push_back(indices_strategy); // indices
- tmp_strategy.push_back(first_input_strategy); // updates
+ tmp_strategy.push_back(updates_strategy); // updates

sp->ResetInputs(tmp_strategy);
}
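
Note: strategy generation composes updates' strategy by the same rule as the tensor map. For the UT shapes (a short sketch):

first_input_strategy = [1, 2, 4]
indices_strategy = [1, 1]  # all-ones, one entry per indices dim
updates_strategy = indices_strategy + first_input_strategy[1:]
assert updates_strategy == [1, 1, 2, 4]  # matches strategy1's updates tuple in the UT below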


mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h (+1 -1)

@@ -45,7 +45,7 @@ class ScatterUpdateInfo : public OperatorInfo {
protected:
Status GetAttrs() override { return SUCCESS; }
Status CheckStrategy(const StrategyPtr &strategy) override;
- Status InferMirrorOps() override { return SUCCESS; }
+ Status InferMirrorOps() override { return SUCCESS; } // the scatter_update only use in eval/predict
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;


tests/ut/python/parallel/test_scatter_update.py (+21 -7)

@@ -22,13 +22,13 @@ from mindspore import context

class Net(nn.Cell):
"""Net definition"""
- def __init__(self):
+ def __init__(self, strategy1=None, strategy2=None):
super(Net, self).__init__()
- self.inputs = Parameter(Tensor(np.ones([32, 128]).astype(np.float32)), "input")
- self.indices = Tensor(np.ones([4]).astype(np.int32))
- self.updates = Tensor(np.ones([4, 128]).astype(np.float32))
- self.scatter_update = P.ScatterUpdate().shard(((1, 8), (1,), (1, 8)))
- self.add = P.TensorAdd().shard(((8, 1), (8, 1)))
+ self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input")
+ self.indices = Tensor(np.ones([4, 8]).astype(np.int32))
+ self.updates = Tensor(np.ones([4, 8, 64, 128]).astype(np.float32))
+ self.scatter_update = P.ScatterUpdate().shard(strategy1)
+ self.add = P.TensorAdd().shard(strategy2)
self.relu = P.ReLU()

def construct(self, x):
@@ -41,7 +41,21 @@ class Net(nn.Cell):
def test_distribute_predict():
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
- inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+ inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
+ strategy1 = ((1, 2, 4), (1, 1), (1, 1, 2, 4))
+ strategy2 = ((1, 2, 4), (1, 2, 4))
+ net = Net(strategy1, strategy2)
model = Model(net)
predict_map = model.infer_predict_layout(inputs)
+ output = model.predict(inputs)
+ context.reset_auto_parallel_context()
+ return predict_map, output


+ def test_distribute_predict_auto_parallel():
+ context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+ context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
+ inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
+ net = Net()
+ model = Model(net)
+ predict_map = model.infer_predict_layout(inputs)
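
Note: both UTs exercise only valid strategies. A sketch of a negative case (a hypothetical test, assuming a rejected strategy surfaces as a RuntimeError during infer_predict_layout, as in other parallel UTs):

import pytest

def test_distribute_predict_invalid_strategy():
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
    strategy1 = ((2, 1, 4), (1, 1), (1, 1, 1, 4))  # splits input's first dim, which CheckStrategy rejects
    strategy2 = ((1, 2, 4), (1, 2, 4))
    net = Net(strategy1, strategy2)
    model = Model(net)
    with pytest.raises(RuntimeError):
        model.infer_predict_layout(inputs)
    context.reset_auto_parallel_context()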

