From 075f680a420bb86173d077125876e710cf361c3e Mon Sep 17 00:00:00 2001
From: yangzhenzhang
Date: Tue, 27 Apr 2021 19:31:19 +0800
Subject: [PATCH] modify scatter update op

---
 .../parallel/ops_info/scatter_update_info.cc  | 44 ++++++++++++++-----
 .../parallel/ops_info/scatter_update_info.h   |  2 +-
 .../ut/python/parallel/test_scatter_update.py | 28 +++++++++---
 3 files changed, 55 insertions(+), 19 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc
index 0554b81598..e2046fbb85 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc
@@ -29,9 +29,14 @@
 namespace mindspore {
 namespace parallel {
+// The first dimension of input can not be split.
 // The indices can not be split.
-// The strategy of input and the strategy of updates must be equal.
-// The first dimension of input or updates can not be split.
+// The first n dimensions (n is indices' dimension size) of updates can not be split.
+// The shape of input: [A, B, ..., M], the strategy of input: (1, b, ..., m)
+// The shape of indices: [N, O, ..., Z], the strategy of indices: (1, 1, ..., 1)
+// The shape of updates: [N, O, ..., Z, B, ..., M], the strategy of updates: (1, 1, ..., 1, b, ..., m)
+// The shape of output: [A, B, ..., M], the strategy of output: (1, b, ..., m)
+// The dev matrix: (1, b, ..., m)
 Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
   MS_EXCEPTION_IF_NULL(strategy);
   if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
@@ -45,11 +50,6 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
-  if (stra[0] != stra[2]) {
-    MS_LOG(ERROR) << name_ << ": The strategy[0] and strategy[2] must be equal";
-    return FAILED;
-  }
-
   if (stra[0].empty()) {
     MS_LOG(ERROR) << name_ << ": The strategy[0] is empty";
     return FAILED;
@@ -65,6 +65,16 @@ Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
     return FAILED;
   }
 
+  if (stra[2].empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy[2] is empty";
+    return FAILED;
+  }
+
+  if (std::accumulate(stra[2].begin(), stra[2].begin() + stra[1].size(), 1, std::multiplies<int64_t>()) != 1) {
+    MS_LOG(ERROR) << name_ << ": The first " << stra[1].size() << " dimensions of updates can not be split";
+    return FAILED;
+  }
+
   return SUCCESS;
 }
@@ -81,22 +91,28 @@ Status ScatterUpdateInfo::InferDevMatrixShape() {
 }
 
 Status ScatterUpdateInfo::InferTensorMap() {
-  TensorMap input_tensor_map;
-  TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE);
   if (inputs_shape_.size() != 3) {
     MS_LOG(ERROR) << name_ << "The size of inputs shape must be 3";
     return FAILED;
   }
+  TensorMap input_tensor_map, updates_tensor_map;
+  TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE);
+  // cannot use dev_matrix_shape_ to replace inputs_shape_[0], because the input may not be fully split across all devices.
   int64_t size = SizeToLong(inputs_shape_[0].size());
   for (int64_t i = 0; i < size; ++i) {
     input_tensor_map.push_back(size - i - 1);
   }
+  // updates_tensor_map = indices_tensor_map + input_tensor_map[1:]
+  updates_tensor_map = indices_tensor_map;
+  for (size_t i = 1; i < input_tensor_map.size(); ++i) {
+    updates_tensor_map.push_back(input_tensor_map[i]);
+  }
 
   inputs_tensor_map_.push_back(input_tensor_map);    // input
   inputs_tensor_map_.push_back(indices_tensor_map);  // indices
-  inputs_tensor_map_.push_back(input_tensor_map);    // updates
+  inputs_tensor_map_.push_back(updates_tensor_map);  // updates
 
   outputs_tensor_map_.push_back(input_tensor_map);
   return SUCCESS;
 }
@@ -169,9 +185,15 @@ Status ScatterUpdateInfo::GenerateStrategies(int64_t stage_id) {
     Strategys tmp_strategy;
     Dimensions first_input_strategy = sp->GetInputDim()[0];
     Dimensions indices_strategy(inputs_shape_[1].size(), 1);
+    // updates_strategy = indices_strategy + input_strategy[1:]
+    Dimensions updates_strategy = indices_strategy;
+    for (size_t i = 1; i < first_input_strategy.size(); ++i) {
+      updates_strategy.push_back(first_input_strategy[i]);
+    }
+
     tmp_strategy.push_back(first_input_strategy);  // input
     tmp_strategy.push_back(indices_strategy);      // indices
-    tmp_strategy.push_back(first_input_strategy);  // updates
+    tmp_strategy.push_back(updates_strategy);      // updates
     sp->ResetInputs(tmp_strategy);
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h
index 3d14dabc3d..7c53a364d8 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h
@@ -45,7 +45,7 @@ class ScatterUpdateInfo : public OperatorInfo {
 protected:
   Status GetAttrs() override { return SUCCESS; }
   Status CheckStrategy(const StrategyPtr &strategy) override;
-  Status InferMirrorOps() override { return SUCCESS; }
+  Status InferMirrorOps() override { return SUCCESS; }  // ScatterUpdate is only used in eval/predict
   Status InferForwardCommunication() override { return SUCCESS; }
   Status InferTensorInfo() override;
   Status InferDevMatrixShape() override;
diff --git a/tests/ut/python/parallel/test_scatter_update.py b/tests/ut/python/parallel/test_scatter_update.py
index 61579f9da2..7f201026d6 100644
--- a/tests/ut/python/parallel/test_scatter_update.py
+++ b/tests/ut/python/parallel/test_scatter_update.py
@@ -22,13 +22,13 @@ from mindspore import context
 
 class Net(nn.Cell):
     """Net definition"""
-    def __init__(self):
+    def __init__(self, strategy1=None, strategy2=None):
         super(Net, self).__init__()
-        self.inputs = Parameter(Tensor(np.ones([32, 128]).astype(np.float32)), "input")
-        self.indices = Tensor(np.ones([4]).astype(np.int32))
-        self.updates = Tensor(np.ones([4, 128]).astype(np.float32))
-        self.scatter_update = P.ScatterUpdate().shard(((1, 8), (1,), (1, 8)))
-        self.add = P.TensorAdd().shard(((8, 1), (8, 1)))
+        self.inputs = Parameter(Tensor(np.ones([32, 64, 128]).astype(np.float32)), "input")
+        self.indices = Tensor(np.ones([4, 8]).astype(np.int32))
+        self.updates = Tensor(np.ones([4, 8, 64, 128]).astype(np.float32))
+        self.scatter_update = P.ScatterUpdate().shard(strategy1)
+        self.add = P.TensorAdd().shard(strategy2)
         self.relu = P.ReLU()
 
     def construct(self, x):
@@ -41,7 +41,21 @@ class Net(nn.Cell):
 def test_distribute_predict():
     context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
-    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
+    strategy1 = ((1, 2, 4), (1, 1), (1, 1, 2, 4))
+    strategy2 = ((1, 2, 4), (1, 2, 4))
+    net = Net(strategy1, strategy2)
+    model = Model(net)
+    predict_map = model.infer_predict_layout(inputs)
+    output = model.predict(inputs)
+    context.reset_auto_parallel_context()
+    return predict_map, output
+
+
+def test_distribute_predict_auto_parallel():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, full_batch=True)
+    inputs = Tensor(np.ones([32, 64, 128]).astype(np.float32))
     net = Net()
     model = Model(net)
     predict_map = model.infer_predict_layout(inputs)
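
Reviewer note: CheckStrategy, InferTensorMap, and GenerateStrategies above all encode the same rule from the header comment: updates_strategy = indices_strategy + input_strategy[1:], with the first dimension of input and every dimension of indices unsplit. A minimal standalone sketch of that rule in plain Python follows (derive_updates_strategy is a hypothetical helper written for illustration, not a MindSpore API):

    def derive_updates_strategy(input_strategy, indices_strategy):
        """updates_strategy = indices_strategy + input_strategy[1:]."""
        # The first dimension of input can not be split.
        assert input_strategy[0] == 1
        # The indices can not be split.
        assert all(s == 1 for s in indices_strategy)
        return tuple(indices_strategy) + tuple(input_strategy[1:])

    # Matches strategy1 in test_distribute_predict: input [32, 64, 128] sharded
    # (1, 2, 4) and indices [4, 8] sharded (1, 1) give updates (1, 1, 2, 4).
    assert derive_updates_strategy((1, 2, 4), (1, 1)) == (1, 1, 2, 4)

Because updates only inherits the split dimensions of input[1:], the single dev matrix (1, b, ..., m) covers all three operands, which is why InferTensorMap can reuse input_tensor_map for the output.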