Merge pull request !221 from chentingting/support_distributd_GatherV2_operator (tag: v0.2.0-alpha)
@@ -623,5 +623,34 @@ double DropOutCost::GetForwardComputationCost(const std::vector<TensorInfo>& inp
   Shape input0_slice_shape = input0.slice_shape();
   return ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * DROPOUT_COST_RATE;
 }
+
+// return the per-device communication cost in the forward phase
+double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                        const int32_t&) const {
+  // GatherV2Cost does not need communication in the forward phase
+  return 0.0;
+}
+
+// return the per-device communication cost in the backward phase
+double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                         const int32_t&) const {
+  // GatherV2Cost does not need communication in the backward phase
+  return 0.0;
+}
+
+double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
+                                               const int32_t&) const {
+  // In the forward phase, the computation cost is slice(A) + slice(B)
+  Shape input0_slice_shape = inputs[0].slice_shape();
+  Shape input1_slice_shape = inputs[1].slice_shape();
+  double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) +
+                  ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
+  return result;
+}
+
+double GatherV2Cost::GetBackwardComputationCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
+                                                const int32_t&) const {
+  return 0.0;
+}
 }  // namespace parallel
 }  // namespace mindspore
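The cost functions above reduce to pure byte arithmetic: zero communication in both phases, and a forward computation cost equal to the byte volume of the two input slices. A minimal Python sketch of the same arithmetic (illustrative names, not MindSpore API):

```python
from functools import reduce

def list_product(shape):
    """Product of all dimensions of a slice shape (mirrors ListProduct)."""
    return reduce(lambda a, b: a * b, shape, 1)

def gather_v2_forward_computation_cost(input_slice, index_slice,
                                       input_bytes=4, index_bytes=4):
    # cost = bytes(slice(input)) + bytes(slice(index)); comm cost is 0.0
    return (list_product(input_slice) * input_bytes
            + list_product(index_slice) * index_bytes)

# A [64, 32] float32 parameter slice and a [16] int32 index slice:
print(gather_v2_forward_computation_cost([64, 32], [16]))  # 8256
```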
@@ -81,6 +81,8 @@ class OperatorCost {
   std::vector<size_t> outputs_type_lengths_;
 };
+using OperatorCostPtr = std::shared_ptr<OperatorCost>;
+
 class MatMulCost : public OperatorCost {
  public:
   MatMulCost() = default;
@@ -525,6 +527,31 @@ class DropOutCost : public OperatorCost {
 };
 using DropOutCostPtr = std::shared_ptr<DropOutCost>;
+
+class GatherV2Cost : public OperatorCost {
+ public:
+  GatherV2Cost() = default;
+  ~GatherV2Cost() override = default;
+
+  double GetCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                     const int32_t& stage_id) const override {
+    return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id);
+  }
+  double GetForwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                            const int32_t& stage_id) const override;
+  double GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                             const int32_t& stage_id) const override;
+  double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                            const int32_t& stage_id) const override {
+    return GetForwardComputationCost(inputs, outputs, stage_id) +
+           GetBackwardComputationCost(inputs, outputs, stage_id);
+  }
+  double GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                                   const int32_t& stage_id) const override;
+  double GetBackwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
+                                    const int32_t&) const override;
+};
+using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_
@@ -228,26 +228,6 @@ void SparseSoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
   }
 }
-void GatherV2Info::ReComputeBatchSplitFlagList() {
-  MS_ASSERT(inputs_shape_.size() == 2);
-  MS_ASSERT(input_value_.size() == 3);
-  MS_ASSERT(input_value_[0] == nullptr);
-  // the second input is the index tensor
-  MS_ASSERT(input_value_[1] != nullptr);
-  // the third input is the axis
-  MS_ASSERT(input_value_[2] != nullptr);
-  int axis = GetValue<int>(input_value_[2]);
-  MS_ASSERT(axis < inputs_shape_[0].size() && axis >= 0 - inputs_shape_[0].size());
-  if (axis < 0) {
-    axis += SizeToInt(inputs_shape_[0].size());
-  }
-  split_flag_list_[0] = true;
-  // if the gather axis is 0, the index's strategy is equal to the device number
-  if (axis == 0) {
-    split_flag_list_[1] = true;
-  }
-}
 Status BatchParallelInfo::InferAsLossDivisor() {
   as_loss_divisor_ = 1;
   return SUCCESS;
@@ -62,15 +62,6 @@ class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
   ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default;
   void ReComputeBatchSplitFlagList() override;
 };
-class GatherV2Info : public BatchParallelInfo {
- public:
-  GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
-               const PrimitiveAttrs& attrs)
-      : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs) {}
-  ~GatherV2Info() override = default;
-  void ReComputeBatchSplitFlagList() override;
-};
 }  // namespace parallel
 }  // namespace mindspore
@@ -0,0 +1,350 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "parallel/ops_info/gather_v2_info.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "ir/meta_tensor.h"
+#include "ir/value.h"
+#include "parallel/auto_parallel/costmodel.h"
+#include "parallel/device_matrix.h"
+#include "parallel/graph_util/generate_graph.h"
+#include "parallel/strategy.h"
+#include "utils/log_adapter.h"
+
+namespace mindspore {
+namespace parallel {
+Status GatherV2Info::GetAttrs() {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  if (input_value_.size() != GATHER_V2_INPUTS_VALUE_SIZE) {
+    MS_LOG(ERROR) << name_ << ": input value size must be " << GATHER_V2_INPUTS_VALUE_SIZE << ", but is "
+                  << input_value_.size();
+    return FAILED;
+  }
+  // The second input is the index tensor; the third input is the axis and must be a ValueNode.
+  if (input_value_.at(2) == nullptr) {
+    MS_LOG(ERROR) << name_ << ": the third input value is nullptr, so it is not a ValueNode!";
+    return FAILED;
+  }
+  if (inputs_shape_.at(0).size() == 0) {
+    MS_LOG(ERROR) << name_ << ": the input cannot be a scalar!";
+    return FAILED;
+  }
+  int axis = GetValue<int>(input_value_.at(2));
+  if (axis >= SizeToInt(inputs_shape_.at(0).size()) || axis < 0 - SizeToInt(inputs_shape_.at(0).size())) {
+    MS_LOG(ERROR) << "Axis is " << axis << ", not in [-" << inputs_shape_.at(0).size() << ", "
+                  << inputs_shape_.at(0).size() << ").";
+    return FAILED;
+  }
+  if (axis < 0) {
+    axis += SizeToInt(inputs_shape_.at(0).size());
+  }
+  axis_ = axis;
+  index_size_ = inputs_shape_.at(1).size();
+  return SUCCESS;
+}
+
+Status GatherV2Info::CheckStrategy(const StrategyPtr& strategy) {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+
+  // Only the strategy of the first input should be set.
+  if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Invalid strategy.";
+    }
+    return FAILED;
+  }
+  axis_strategy_ = strategy->GetInputDim().at(0).at(axis_);
+  if (index_size_ != 1 && axis_strategy_ != 1) {
+    MS_LOG(ERROR) << name_
+                  << ": Invalid strategy. If the index is a scalar or has more than one dimension, the strategy "
+                     "corresponding to the axis must be 1, but it is "
+                  << axis_strategy_;
+    return FAILED;
+  }
+  if (index_size_ == 1 && axis_strategy_ != 1 && inputs_shape_.at(1).at(0) % axis_strategy_ != 0) {
+    MS_LOG(ERROR) << name_
+                  << ": Invalid strategy. The first dimension of the index cannot be divided by the strategy "
+                     "corresponding to the axis. The first dimension of the index is "
+                  << inputs_shape_.at(1).at(0) << ", the strategy corresponding to the axis is " << axis_strategy_;
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+Status GatherV2Info::InferDevMatrixShape() {
+  std::vector<Dimensions> stra = strategy_->GetInputDim();
+  dev_matrix_shape_ = stra.at(0);
+  return SUCCESS;
+}
+
+// If the index is a scalar, the output dimension is the input dimension minus 1;
+// if the index is an n-dimensional tensor, the output dimension is the input dimension plus (n - 1).
+// Each tensor map entry equals the corresponding input/output dimension.
+// If the index has more than one dimension, -1 is inserted into the output tensor map for the extra dimensions.
+Status GatherV2Info::InferTensorMap() {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  std::vector<int32_t> tensor_map_in;
+  std::vector<int32_t> tensor_map_out;
+  size_t size = inputs_shape_.at(0).size();
+  // e.g. for size 4: tensor map [3, 2, 1, 0]
+  for (size_t i = 0; i < size; ++i) {
+    tensor_map_in.push_back(SizeToInt(size - i - 1));
+    tensor_map_out.push_back(SizeToInt(size - i - 1));
+  }
+  if (index_size_ == 0) {
+    (void)tensor_map_out.erase(tensor_map_out.begin() + axis_);
+  } else if (index_size_ > 1) {
+    (void)tensor_map_out.insert(tensor_map_out.begin() + axis_, index_size_ - 1, -1);
+  }
+  if (tensor_map_out.size() != outputs_shape_.at(0).size()) {
+    MS_LOG(ERROR) << "The out tensor map size is not equal to the output size! The out tensor map size is "
+                  << tensor_map_out.size() << ", the output size is " << outputs_shape_.at(0).size();
+    return FAILED;
+  }
+
+  std::vector<int32_t> tensor_map_in_index;
+  if (index_size_ >= 1) {
+    tensor_map_in_index.push_back(SizeToInt(size - axis_ - 1));
+  }
+  for (size_t i = 1; i < index_size_; ++i) {
+    tensor_map_in_index.push_back(-1);
+  }
+  inputs_tensor_map_.emplace_back(std::move(tensor_map_in));
+  inputs_tensor_map_.emplace_back(std::move(tensor_map_in_index));
+  outputs_tensor_map_.emplace_back(std::move(tensor_map_out));
+  return SUCCESS;
+}
+
+Status GatherV2Info::InferTensorInfo() {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs shape size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_shape_.size();
+    return FAILED;
+  }
+  if (inputs_tensor_map_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs tensor map size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_tensor_map_.size();
+    return FAILED;
+  }
+  if (outputs_tensor_map_.size() != GATHER_V2_OUTPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": outputs tensor map size must be " << GATHER_V2_OUTPUTS_SIZE << ", but is "
+                  << outputs_tensor_map_.size();
+    return FAILED;
+  }
+  // infer the tensor layouts
+  Shape input_shape = inputs_shape_.at(0);
+  Shape input_index_shape = inputs_shape_.at(1);
+  Shape output_shape = outputs_shape_.at(0);
+  TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout;
+  if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) ||
+      (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) !=
+       SUCCESS) ||
+      (output_tensor_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != SUCCESS)) {
+    return FAILED;
+  }
+  TensorInfo input_tensor_info(input_tensor_layout);
+  TensorInfo input_index_info(input_index_layout);
+  TensorInfo output_tensor_info(output_tensor_layout);
+
+  inputs_tensor_info_.push_back(input_tensor_info);
+  inputs_tensor_info_.push_back(input_index_info);
+  outputs_tensor_info_.push_back(output_tensor_info);
+  return SUCCESS;
+}
+
+// Build a Sub operator that shifts the raw indices into this device's local range.
+OperatorVector CreateSubOp(int32_t sub_value) {
+  OperatorVector ops;
+  OperatorName operator_name = SUB;
+  OperatorAttrs operator_attrs;
+
+  py::tuple tuple = py::make_tuple(sub_value);
+  mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(tuple, kInt32);
+  ValuePtr op_param_value = MakeValue(tensor_ptr);
+
+  Attr op1_param = std::make_pair("", op_param_value);
+  OperatorParams operator_param = {std::make_pair(op1_param, 2)};
+
+  OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
+  Operator op = std::make_pair(operator_name, operator_args);
+  ops.push_back(op);
+  return ops;
+}
+
+Status GatherV2Info::InferTensorSubOps() {
+  sub_ops_.clear();
+  if ((index_size_ == 0) || (axis_strategy_ == 1)) {
+    return SUCCESS;
+  }
+  int32_t mod_n = 1;
+  for (size_t i = IntToSize(axis_) + 1; i < dev_matrix_shape_.size(); i++) {
+    mod_n *= dev_matrix_shape_.at(i);
+  }
+  if ((axis_ >= SizeToInt(dev_matrix_shape_.size())) || axis_ < 0) {
+    MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << dev_matrix_shape_.size() << ").";
+    return FAILED;
+  }
+  int32_t mod_p = mod_n * dev_matrix_shape_.at(axis_);
+  int32_t rank = g_device_manager->global_rank();
+  int32_t mod_rank = rank % mod_p;
+  mod_rank = static_cast<int32_t>(mod_rank / mod_n);
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                  << inputs_shape_.size();
+    return FAILED;
+  }
+  if ((axis_ >= SizeToInt(inputs_shape_.at(0).size())) || axis_ < 0) {
+    MS_LOG(ERROR) << "Axis is " << axis_ << ", not in [0, " << inputs_shape_.at(0).size() << ").";
+    return FAILED;
+  }
+  int32_t sub_value = static_cast<int32_t>(inputs_shape_.at(0).at(axis_) / dev_matrix_shape_.at(axis_)) * mod_rank;
+  OperatorVector sub_op;
+  sub_ops_.emplace_back(std::move(sub_op));  // no Sub for the first input (the parameter)
+  sub_op = CreateSubOp(sub_value);
+  sub_ops_.emplace_back(std::move(sub_op));  // Sub applied to the second input (the indices)
+  return SUCCESS;
+}
+
+Status GatherV2Info::Init(const StrategyPtr& strategy) {
+  if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": Init failed.";
+    return FAILED;
+  }
+  Status status = InferTensorSubOps();
+  if (status != SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": InferTensorSubOps failed.";
+    return status;
+  }
+  MS_LOG(INFO) << name_ << ": Init success.";
+  return SUCCESS;
+}
+
+Status GatherV2Info::InitForCostModel(const StrategyPtr& strategy) {
+  if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Init for cost model failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
+    }
+    return FAILED;
+  }
+  MS_LOG(INFO) << name_ << ": Init for cost model success.";
+  return SUCCESS;
+}
+
+Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
+  if ((inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) || (outputs_shape_.size() != GATHER_V2_OUTPUTS_SIZE)) {
+    MS_LOG(ERROR) << name_ << ": Inputs shape size (" << inputs_shape_.size() << ") or outputs shape size ("
+                  << outputs_shape_.size() << ") is wrong.";
+    return FAILED;
+  }
+
+  is_auto_parallel_ = true;
+  Shape input0_split(inputs_shape_[0].size());
+  Shapes splittable_inputs = {input0_split};
+  std::vector<StrategyPtr> sp_vector;
+  if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) !=
+      SUCCESS) {
+    MS_LOG(ERROR) << name_ << ": GenerateStrategiesForIndependentInputs failed.";
+    return FAILED;
+  }
+  size_t success = 0;
+  for (auto& sp : sp_vector) {
+    if (SetCostUnderStrategy(sp) == SUCCESS) {
+      success++;
+      MS_LOG(INFO) << name_ << ": Successfully generated strategy " << success << ".";
+      PrintStrategy(sp);
+    }
+  }
+  return SUCCESS;
+}
+
+Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr& strategy) {
+  if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
+    if (is_auto_parallel_) {
+      MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
+    } else {
+      MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
+    }
+    return FAILED;
+  }
+  return SUCCESS;
+}
+
+std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2Info::GenerateBatchStrategies() {
+  if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {
+    MS_LOG(EXCEPTION) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is "
+                      << inputs_shape_.size();
+  }
+  CheckGlobalDeviceManager();
+  size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size();
+  if (GetAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << "GetAttrs failed!";
+  }
+
+  // Split the first input dimension across all devices only when the index is a 1-D tensor.
+  Dimensions strategy;
+  if (index_size_ != 1) {
+    strategy.push_back(1);
+  } else {
+    strategy.push_back(SizeToInt(dev_num));
+  }
+  for (size_t i = 1; i < inputs_shape_[0].size(); i++) {
+    strategy.push_back(1);
+  }
+  std::vector<Dimensions> strategy_v = {strategy};
+  return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v);
+}
+}  // namespace parallel
+}  // namespace mindspore
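GenerateBatchStrategies only ever splits the first input dimension, and only when the index is 1-D; a compact Python sketch under those assumptions:

```python
def generate_batch_strategy(input_shape, index_dim, dev_num):
    """Batch-parallel strategy: split dim 0 across devices iff the index is 1-D."""
    first = dev_num if index_dim == 1 else 1
    return [[first] + [1] * (len(input_shape) - 1)]

print(generate_batch_strategy([4096, 32], index_dim=1, dev_num=32))  # [[32, 1]]
print(generate_batch_strategy([4096, 32], index_dim=2, dev_num=32))  # [[1, 1]]
```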
@@ -0,0 +1,73 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
+#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ir/value.h"
+#include "parallel/auto_parallel/operator_costmodel.h"
+#include "parallel/ops_info/operator_info.h"
+#include "parallel/strategy.h"
+
+namespace mindspore {
+namespace parallel {
+constexpr size_t GATHER_V2_INPUTS_SIZE = 2;
+constexpr size_t GATHER_V2_OUTPUTS_SIZE = 1;
+constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3;
+// Only limited parallel strategies are supported for now.
+// If the strategy corresponding to the axis is greater than 1, the index must be evenly distributed
+// across the axis dimension of the input.
+// If the index is a scalar or an n-dimensional tensor (n > 1), the strategy corresponding to the axis must be 1.
+class GatherV2Info : public OperatorInfo {
+ public:
+  GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+               const PrimitiveAttrs& attrs)
+      : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2Cost>()),
+        axis_(-1),
+        index_size_(0),
+        axis_strategy_(1) {}
+  ~GatherV2Info() override = default;
+  Status Init(const StrategyPtr& strategy) override;
+  Status InitForCostModel(const StrategyPtr& strategy) override;
+
+  Status GenerateStrategies(int32_t stage_id) override;
+  Status SetCostUnderStrategy(const StrategyPtr& strategy) override;
+  std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
+
+ protected:
+  Status CheckStrategy(const StrategyPtr& strategy) override;
+  Status InferMirrorOps() override { return SUCCESS; }
+  Status InferForwardCommunication() override { return SUCCESS; }
+  Status InferTensorInfo() override;
+  Status InferDevMatrixShape() override;
+  Status InferTensorMap() override;
+  Status GetAttrs() override;
+
+ private:
+  Status InferTensorSubOps();
+
+  int32_t axis_;
+  size_t index_size_;
+  int32_t axis_strategy_;
+};
+}  // namespace parallel
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_INFO_H_
@@ -112,6 +112,7 @@ void OperatorInfo::ResetQueueMember() {
   dev_matrix_shape_.clear();
   forward_op_.clear();
   mirror_ops_.clear();
+  sub_ops_.clear();
   replace_op_.clear();
   replace_op_info_.clear();
   virtual_div_op_.clear();
@@ -41,6 +41,7 @@ namespace mindspore {
 namespace parallel {
 using ForwardOp = OperatorVector;
 using MirrorOps = std::vector<OperatorVector>;
+using Ops = std::vector<OperatorVector>;
 using VirtualDivOp = OperatorVector;
 using TensorMaps = std::vector<std::vector<int32_t>>;
 using TensorLayouts = std::vector<TensorLayout>;
@@ -99,6 +100,7 @@ class OperatorInfo {
   OutPutInfoVector replace_op_info() const { return replace_op_info_; }
   virtual ReplaceGraphPtr replace_graph(const CNodePtr&) { return replace_graph_; }
   MirrorOps mirror_ops() const { return mirror_ops_; }
+  Ops sub_ops() const { return sub_ops_; }
   VirtualDivOp virtual_div_op() const { return virtual_div_op_; }
   Shape dev_matrix_shape() const { return dev_matrix_shape_; }
   std::vector<TensorInfo> inputs_tensor_info() const { return inputs_tensor_info_; }
@@ -190,6 +192,7 @@ class OperatorInfo {
   TensorMaps inputs_tensor_map_;
   TensorMaps outputs_tensor_map_;
   ForwardOp forward_op_;
+  Ops sub_ops_;
   ForwardOp replace_op_;
   OutPutInfoVector replace_op_info_;
   ReplaceGraphPtr replace_graph_;
@@ -24,6 +24,7 @@
 #include "parallel/ops_info/comparison_function_info.h"
 #include "parallel/ops_info/dropout_do_mask_info.h"
 #include "parallel/ops_info/elementary_function_info.h"
+#include "parallel/ops_info/gather_v2_info.h"
 #include "parallel/ops_info/get_next_info.h"
 #include "parallel/ops_info/l2_normalize_info.h"
 #include "parallel/ops_info/loss_info.h"
@@ -464,6 +464,14 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) {
   MS_EXCEPTION_IF_NULL(func_graph);
   Operator op = CreateGetTensorSliceOp(tensor_layout);
   InsertGetTensorSliceOp(op, next_node, func_graph, index, SPLIT_TENSOR);
+  // If the operator defines sub-ops (e.g. the index-offset Sub of GatherV2), insert them as well.
+  if (!op_info->sub_ops().empty()) {
+    auto sub_ops = op_info->sub_ops();
+    for (size_t i = 0; i < sub_ops.size(); i++) {
+      if (!sub_ops.at(i).empty()) {
+        InsertGetTensorSliceOp(sub_ops.at(i).at(0), next_node, func_graph, index, SUB);
+      }
+    }
+  }
 }

 void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) {
@@ -29,6 +29,8 @@ from mindspore.nn import Dense, Cell
 from mindspore import context

 context.set_context(mode=context.GRAPH_MODE)
+device_number = 32
+batch_size_per_device = 128


 class Dataset():
@@ -57,15 +59,22 @@ class Dataset():
 class GatherV2(_Loss):
-    def __init__(self, batchsize):
+    def __init__(self, index_dim, strategy, index_size=16):
         super(GatherV2, self).__init__()
         self.pow = P.Pow()
-        emb_list = list(range(batchsize))
-        emb1_list = emb_list[0::2]
-        emb2_list = emb_list[1::2]
+        emb1_list = 21
+        emb2_list = 2
+        if index_dim == 1:
+            emb_list = list(range(index_size))
+            emb1_list = emb_list[0::2]
+            emb2_list = emb_list[1::2]
+        if index_dim == 2:
+            emb_list = np.arange(index_size * 16)
+            emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), 16))
+            emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), 16))
         self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
         self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
-        self.gatherv2 = P.GatherV2()
+        self.gatherv2 = P.GatherV2().set_strategy(strategy)

     def construct(self, nembeddings):
         emb1 = self.gatherv2(nembeddings, self.emb1_param, 0)
@@ -73,10 +82,6 @@ class GatherV2(_Loss):
         return self.pow((emb1 - emb2), 2.0)


-def get_loss(batchsize):
-    return GatherV2(batchsize)
-
-
 def fc_with_initialize(input_channels, out_channels):
     return Dense(input_channels, out_channels)
@@ -114,26 +119,23 @@ class TrainOneStepCell(Cell):
         return F.depend(loss, self.optimizer(grads))


-def test_trains():
+def net_trains(gather_v2_strategy, criterion, rank):
     init()
     lr = 0.1
     momentum = 0.9
     max_epoch = 20
-    device_number = 32
-    batch_size_per_device = 128
     input_channels = 256
     out_channels = 512
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
     context.reset_auto_parallel_context()
-    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number)
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=device_number,
+                                      global_rank=rank)
     predict = Tensor(np.ones([batch_size_per_device, input_channels]), dtype=ms.float32)
     dataset = Dataset(predict, 4)
     network = fc_with_initialize(input_channels, out_channels)
     network.set_train()

-    criterion = get_loss(batch_size_per_device * device_number)
-
     train_network = BuildTrainNetwork(network, criterion)
     train_network.set_train()
     opt = Momentum(train_network.trainable_params(), lr, momentum)
@@ -143,5 +145,90 @@ def test_trains():
     model.train(max_epoch, dataset, dataset_sink_mode=False)
     context.reset_auto_parallel_context()

-if __name__ == "__main__":
-    test_trains()
+
+def test_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_2d_index_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2(2, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_batch_parallel():
+    gather_v2_strategy = ((device_number, 1),)
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy1():
+    gather_v2_strategy = ((16, 2),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy2():
+    gather_v2_strategy = ((1, device_number),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_strategy3():
+    gather_v2_strategy = ((8, 1),)
+    rank = 2
+    criterion = GatherV2(1, strategy=gather_v2_strategy, index_size=batch_size_per_device * device_number)
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+class GatherV2Axis1(_Loss):
+    def __init__(self, index_dim, strategy, index_size=16):
+        super(GatherV2Axis1, self).__init__()
+        self.pow = P.Pow()
+        emb1_list = 21
+        emb2_list = 2
+        if index_dim == 1:
+            emb_list = list(range(index_size))
+            emb1_list = emb_list[0::2]
+            emb2_list = emb_list[1::2]
+        if index_dim == 2:
+            emb_list = np.arange(index_size * index_size)
+            emb1_list = np.reshape(emb_list[0::2], (int(index_size / 2), index_size))
+            emb2_list = np.reshape(emb_list[1::2], (int(index_size / 2), index_size))
+        self.emb1_param = Tensor(emb1_list, dtype=mstype.int32)
+        self.emb2_param = Tensor(emb2_list, dtype=mstype.int32)
+        self.gatherv2 = P.GatherV2().set_strategy(strategy)
+
+    def construct(self, nembeddings):
+        emb1 = self.gatherv2(nembeddings, self.emb1_param, 1)
+        emb2 = self.gatherv2(nembeddings, self.emb2_param, 1)
+        return self.pow((emb1 - emb2), 2.0)
+
+
+def test_axis1_auto_batch_parallel():
+    gather_v2_strategy = None
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_axis1_batch_parallel():
+    gather_v2_strategy = ((device_number, 1),)
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    rank = 2
+    net_trains(gather_v2_strategy, criterion, rank)
+
+
+def test_axis1_strategy1():
+    gather_v2_strategy = ((16, 2),)
+    rank = 17
+    criterion = GatherV2Axis1(1, strategy=gather_v2_strategy, index_size=512)
+    net_trains(gather_v2_strategy, criterion, rank)
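As a sanity check on the semantics these tests exercise, here is a small numpy-only simulation (outside MindSpore, illustrative shapes) showing that slicing the table on axis 0 and subtracting the rank offset from the indices reproduces the unsharded gather when each index lands in the local shard:

```python
import numpy as np

dev_num, rows, cols = 4, 16, 8
table = np.arange(rows * cols).reshape(rows, cols)
rows_per_dev = rows // dev_num

full = table[np.array([5, 6])]                        # unsharded gather, axis 0

rank = 1                                              # device owning rows [4, 8)
shard = table[rank * rows_per_dev:(rank + 1) * rows_per_dev]
local_idx = np.array([5, 6]) - rank * rows_per_dev    # the inserted Sub op
sharded = shard[local_idx]

assert (full == sharded).all()
```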