diff --git a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
index d1f676fa58..3fef91f290 100644
--- a/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
+++ b/mindspore/ccsrc/frontend/parallel/auto_parallel/operator_costmodel.h
@@ -185,6 +185,7 @@ using ZerosLikeCost = CastCost;
 using OnesLikeCost = CastCost;
 using RangeCost = CastCost;
 using SplitCost = CastCost;
+using ScatterUpdateCost = CastCost;
 
 class SqrtCost : public CastCost {
  public:
diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
index f8a3b6e335..0e35a3f918 100644
--- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
@@ -193,6 +193,7 @@ REGISTER(SplitInfo);
 REGISTER(UniqueInfo);
 REGISTER(GatherNdInfo);
 REGISTER(TopKInfo);
+REGISTER(ScatterUpdateInfo);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
index 183614a19b..8e1e4d67dc 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_info_head_files.h
@@ -52,5 +52,6 @@
 #include "frontend/parallel/ops_info/reluv2_info.h"
 #include "frontend/parallel/ops_info/gathernd_info.h"
 #include "frontend/parallel/ops_info/topk_info.h"
+#include "frontend/parallel/ops_info/scatter_update_info.h"
 
 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index 0abf529b37..322c2b719a 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -326,6 +326,7 @@ constexpr char DROPOUT[] = "Dropout";
 constexpr char KStridedSlice[] = "StridedSlice";
 constexpr char UNIQUE[] = "Unique";
 constexpr char GATHERND[] = "GatherNd";
+constexpr char SCATTER_UPDATE[] = "ScatterUpdate";
 
 // Parallel don't care
 constexpr char STRING_EQUAL[] = "string_equal";
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc
new file mode 100644
index 0000000000..0554b81598
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc
@@ -0,0 +1,210 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/ops_info/scatter_update_info.h"
+
+#include <functional>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+#include "frontend/parallel/device_matrix.h"
+#include "frontend/parallel/strategy.h"
+#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
+#include "pipeline/jit/resource.h"
+
+namespace mindspore {
+namespace parallel {
+// The indices can not be split.
+// The strategy of input and the strategy of updates must be equal.
+// The first dimension of input or updates can not be split.
+Status ScatterUpdateInfo::CheckStrategy(const StrategyPtr &strategy) {
+  MS_EXCEPTION_IF_NULL(strategy);
+  if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Invalid strategy";
+    return FAILED;
+  }
+
+  std::vector<Dimensions> stra = strategy->GetInputDim();
+  if (stra.size() != 3) {
+    MS_LOG(ERROR) << name_ << ": The size of strategy must be 3";
+    return FAILED;
+  }
+
+  if (stra[0] != stra[2]) {
+    MS_LOG(ERROR) << name_ << ": The strategy[0] and strategy[2] must be equal";
+    return FAILED;
+  }
+
+  if (stra[0].empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy[0] is empty";
+    return FAILED;
+  }
+
+  if (stra[0][0] != 1) {
+    MS_LOG(ERROR) << name_ << ": The first dimension of input can not be split";
+    return FAILED;
+  }
+
+  if (!stra[1].empty() && std::accumulate(stra[1].begin(), stra[1].end(), 1, std::multiplies<int64_t>()) != 1) {
+    MS_LOG(ERROR) << name_ << ": The indices can not be split";
+    return FAILED;
+  }
+
+  return SUCCESS;
+}
+
+Status ScatterUpdateInfo::InferDevMatrixShape() {
+  MS_EXCEPTION_IF_NULL(strategy_);
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  if (stra.empty()) {
+    MS_LOG(ERROR) << name_ << ": The strategy is empty";
+    return FAILED;
+  }
+
+  dev_matrix_shape_ = stra[0];
+  return SUCCESS;
+}
+
+Status ScatterUpdateInfo::InferTensorMap() {
+  if (inputs_shape_.size() != 3) {
+    MS_LOG(ERROR) << name_ << ": The size of inputs shape must be 3";
+    return FAILED;
+  }
+
+  TensorMap input_tensor_map;
+  TensorMap indices_tensor_map(inputs_shape_[1].size(), MAP_NONE);
+
+  // cannot use dev_matrix_shape_ to replace inputs_shape_[0], because it may not be fully split in all devices
+  int64_t size = SizeToLong(inputs_shape_[0].size());
+  for (int64_t i = 0; i < size; ++i) {
+    input_tensor_map.push_back(size - i - 1);
+  }
+
+  inputs_tensor_map_.push_back(input_tensor_map);    // input
+  inputs_tensor_map_.push_back(indices_tensor_map);  // indices
+  inputs_tensor_map_.push_back(input_tensor_map);    // updates
+
+  outputs_tensor_map_.push_back(input_tensor_map);
+  return SUCCESS;
+}
+
+Status ScatterUpdateInfo::InferTensorInfo() {
+  if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
+    MS_LOG(ERROR) << name_ << ": Invalid args";
+    return FAILED;
+  }
+
+  TensorLayout input_layout, output_layout;
+  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
+    // infer tensor layout
+    if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
+      MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
+      return FAILED;
+    }
+    TensorInfo input_tensor_info(input_layout);
+    inputs_tensor_info_.push_back(input_tensor_info);
+  }
+
+  if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
+    return FAILED;
+  }
+  TensorInfo output_tensor_info(output_layout);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+void ScatterUpdateInfo::ReComputeBatchSplitFlagList() {
+  for (size_t i = 0; i < inputs_shape_.size(); i++) {
+    split_flag_list_[i] = false;  // the first dimension can not be split, so no input is split along the batch dim
+  }
+}
+
+Status ScatterUpdateInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
+  return SetCostUnderStrategyBase(strategy);
+}
+
+Status ScatterUpdateInfo::GenerateStrategies(int64_t stage_id) {
+  if (InferAttrs() != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Infer attrs failed";
+    return FAILED;
+  }
+  if (inputs_shape_.empty()) {
+    MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
+    return FAILED;
+  }
+
+  // to generate the first input's strategy
+  Shape input_split(inputs_shape_[0].size(), 1);
+  input_split[0] = 0;
+  Shapes splittable_input = {input_split};
+  Shapes tmp_inputs_shape = {inputs_shape_[0]};
+
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_input, &sp_vector) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Generate strategies failed";
+    return FAILED;
+  }
+
+  // the other inputs' strategies are equal to the first input's strategy
+  for (auto &sp : sp_vector) {
+    if ((sp == nullptr) || sp->GetInputDim().empty()) {
+      MS_LOG(ERROR) << name_ << ": The strategy is null or empty";
+      return FAILED;
+    }
+    Strategys tmp_strategy;
+    Dimensions first_input_strategy = sp->GetInputDim()[0];
+    Dimensions indices_strategy(inputs_shape_[1].size(), 1);
+    tmp_strategy.push_back(first_input_strategy);  // input
+    tmp_strategy.push_back(indices_strategy);      // indices
+    tmp_strategy.push_back(first_input_strategy);  // updates
+
+    sp->ResetInputs(tmp_strategy);
+  }
+
+  size_t success = 0;
+  for (auto &sp : sp_vector) {
+    PrintStrategy(sp);
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+
+Status ScatterUpdateInfo::Init(const StrategyPtr &strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status ScatterUpdateInfo::InitForCostModel(const StrategyPtr &strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    return FAILED;
+  }
+
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h
new file mode 100644
index 0000000000..3d14dabc3d
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_UPDATE_INFO_H_
+#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_UPDATE_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ir/value.h"
+#include "frontend/parallel/auto_parallel/operator_costmodel.h"
+#include "frontend/parallel/ops_info/operator_info.h"
+#include "frontend/parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+class ScatterUpdateInfo : public OperatorInfo {
+ public:
+  ScatterUpdateInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+                    const PrimitiveAttrs &attrs)
+      : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterUpdateCost>()) {}
+  ~ScatterUpdateInfo() override = default;
+
+  Status Init(const StrategyPtr &strategy) override;
+  Status InitForCostModel(const StrategyPtr &strategy) override;
+  Status GenerateStrategies(int64_t) override;
+  Status SetCostUnderStrategy(const StrategyPtr &) override;
+  void ReComputeBatchSplitFlagList() override;
+
+ protected:
+  Status GetAttrs() override { return SUCCESS; }
+  Status CheckStrategy(const StrategyPtr &strategy) override;
+  Status InferMirrorOps() override { return SUCCESS; }
+  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+};
+
+using ScatterUpdateInfoPtr = std::shared_ptr<ScatterUpdateInfo>;
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_SCATTER_UPDATE_INFO_H_
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 08bac77f9b..74f4923c50 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -163,7 +163,7 @@ bool IsSplittableOperator(const std::string &op_name) {
     BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2,
     SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD, UNIQUE, UNSORTED_SEGMENT_SUM,
     UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE,
-    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK};
+    UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE};
   // clang-format on
 
   auto iter = splittable_op.find(op_name);
diff --git a/tests/ut/python/parallel/test_scatter_update.py b/tests/ut/python/parallel/test_scatter_update.py
new file mode 100644
index 0000000000..61579f9da2
--- /dev/null
+++ b/tests/ut/python/parallel/test_scatter_update.py
@@ -0,0 +1,50 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" test scatter update """
+import numpy as np
+import mindspore.nn as nn
+from mindspore import Tensor, Model, Parameter
+from mindspore.ops import operations as P
+from mindspore import context
+
+
+class Net(nn.Cell):
+    """Net definition"""
+    def __init__(self):
+        super(Net, self).__init__()
+        self.inputs = Parameter(Tensor(np.ones([32, 128]).astype(np.float32)), "input")
+        self.indices = Tensor(np.ones([4]).astype(np.int32))
+        self.updates = Tensor(np.ones([4, 128]).astype(np.float32))
+        self.scatter_update = P.ScatterUpdate().shard(((1, 8), (1,), (1, 8)))
+        self.add = P.TensorAdd().shard(((8, 1), (8, 1)))
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        out = self.scatter_update(self.inputs, self.indices, self.updates)
+        out = self.add(x, out)
+        out = self.relu(out)
+        return out
+
+
+def test_distribute_predict():
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
+    inputs = Tensor(np.ones([32, 128]).astype(np.float32))
+    net = Net()
+    model = Model(net)
+    predict_map = model.infer_predict_layout(inputs)
+    output = model.predict(inputs)
+    context.reset_auto_parallel_context()
+    return predict_map, output
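
Note on the shard strategies accepted by ScatterUpdateInfo::CheckStrategy: the indices strategy must be all ones, the first dimension of input/updates must stay unsplit, and input and updates must use the same strategy. The snippet below is a minimal illustrative sketch, not part of the patch; it assumes the 8-device layout and the [32, 128] / [4] / [4, 128] shapes of test_scatter_update.py above, and the rejected strategy tuples are hypothetical examples.

    # Illustrative only: strategies CheckStrategy would accept or reject for
    # ScatterUpdate(input[32, 128], indices[4], updates[4, 128]) on 8 devices.
    from mindspore.ops import operations as P

    # Accepted: indices unsplit, first dimension unsplit, input strategy == updates strategy.
    scatter_update = P.ScatterUpdate().shard(((1, 8), (1,), (1, 8)))

    # Rejected: the first dimension of input/updates is split.
    bad_first_dim = ((8, 1), (1,), (8, 1))
    # Rejected: the indices are split.
    bad_indices = ((1, 8), (2,), (1, 8))
    # Rejected: input and updates strategies differ.
    bad_mismatch = ((1, 8), (1,), (8, 1))

Keeping the row dimension unsplit presumably guarantees that every row addressed by indices is locally available on each device, so the update needs no forward communication, which is consistent with InferForwardCommunication and InferMirrorOps being no-ops in this operator.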