Merge pull request !1023 from lichen/add_gatherv2_distributed_optags/v0.3.0-alpha
| @@ -787,5 +787,90 @@ double LayerNormCost::GetForwardComputationCost(const std::vector<TensorInfo> &i | |||
| } | |||
| return result; | |||
| } | |||
| double GatherV2PCost::GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t stage_id) const { | |||
| double result = 0.0; | |||
| if (outputs_type_lengths_.size() != outputs.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; | |||
| } | |||
| // don't split axis | |||
| if (strategy_.at(IntToSize(axis_)) == 1) { | |||
| return result; | |||
| } | |||
| // split axis | |||
| auto param_shape = inputs[0].slice_shape(); | |||
| auto index_shape = inputs[1].slice_shape(); | |||
| Shape reducescatter_shape = index_shape; | |||
| if (param_shape.size() == 2) { | |||
| reducescatter_shape.push_back(param_shape.at(1 - axis_)); | |||
| } | |||
| result += ListProduct(reducescatter_shape) * static_cast<double>(outputs_type_lengths_[0]); | |||
| return result; | |||
| } | |||
| double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t stage_id) const { | |||
| double result = 0.0; | |||
| CheckGlobalDeviceManager(); | |||
| MS_EXCEPTION_IF_NULL(g_device_manager); | |||
| auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); | |||
| for (size_t j = 0; j < inputs.size(); ++j) { | |||
| if (!is_parameter_[j]) { | |||
| continue; | |||
| } | |||
| TensorInfo input_a_tensor_info = inputs[j]; | |||
| Shape input_a_shape = input_a_tensor_info.shape(); | |||
| Shape input_a_slice_shape = input_a_tensor_info.slice_shape(); | |||
| int32_t used_device_num = 1; | |||
| for (size_t i = 0; i < input_a_shape.size(); ++i) { | |||
| used_device_num *= input_a_shape[i] / input_a_slice_shape[i]; | |||
| } | |||
| if (total_device_num != IntToSize(used_device_num)) { | |||
| result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[0]); | |||
| } | |||
| } | |||
| return result; | |||
| } | |||
| double GatherV2PCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, | |||
| const std::vector<TensorInfo> &outputs, int32_t stage_id) const { | |||
| double result = 0.0; | |||
| Shape input0_slice_shape = inputs[0].slice_shape(); | |||
| Shape input1_slice_shape = inputs[1].slice_shape(); | |||
| if (inputs_type_lengths_.size() != inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() << " for gatherv2 cost"; | |||
| } | |||
| // don't split axis | |||
| if (strategy_.at(IntToSize(axis_)) == 1) { | |||
| result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) + | |||
| ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]); | |||
| } else { | |||
| // split axis | |||
| result += ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT0 + | |||
| ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT1; | |||
| } | |||
| return result; | |||
| } | |||
| double GatherV2PCost::GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, | |||
| const std::vector<TensorInfo> &outputs, int32_t) const { | |||
| double result = 0.0; | |||
| Shape input1_slice_shape = inputs[1].slice_shape(); | |||
| Shape output0_slice_shape = outputs[0].slice_shape(); | |||
| // don't split axis | |||
| if (strategy_.at(IntToSize(axis_)) == 1) { | |||
| result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]); | |||
| } else { | |||
| // split axis | |||
| result += ListProduct(output0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]) * GATHERV2_COST_WEIGHT2 + | |||
| ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]) * GATHERV2_COST_WEIGHT3; | |||
| } | |||
| return result; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -27,6 +27,10 @@ namespace parallel { | |||
| #define MAXIMUM_INPUT_NUMBER 100 | |||
| #define DEFAULT_DATA_TYPE_LENGTH 4 | |||
| #define DROPOUT_COST_RATE 1.125 // the DropoutGenMask need 12.5% memory | |||
| #define GATHERV2_COST_WEIGHT0 3 | |||
| #define GATHERV2_COST_WEIGHT1 7 | |||
| #define GATHERV2_COST_WEIGHT2 2 | |||
| #define GATHERV2_COST_WEIGHT3 6 | |||
| class OperatorCost; | |||
| using OperatorCostPtr = std::shared_ptr<OperatorCost>; | |||
| @@ -609,6 +613,38 @@ class GatherV2Cost : public OperatorCost { | |||
| }; | |||
| using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>; | |||
| class GatherV2PCost : public OperatorCost { | |||
| public: | |||
| explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} | |||
| GatherV2PCost() : OperatorCost(true) {} | |||
| ~GatherV2PCost() override = default; | |||
| double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t stage_id) const override { | |||
| return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); | |||
| } | |||
| double GetForwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t stage_id) const override; | |||
| double GetBackwardCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t stage_id) const override; | |||
| double GetComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t stage_id) const override { | |||
| return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); | |||
| } | |||
| double GetForwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t stage_id) const override; | |||
| double GetBackwardComputationCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, | |||
| int32_t) const override; | |||
| void set_axis(int32_t axis) { axis_ = axis; } | |||
| void set_strategy(const Shape &strategy) { strategy_ = strategy; } | |||
| protected: | |||
| int32_t axis_; | |||
| Shape strategy_; | |||
| }; | |||
| using GatherV2PCostPtr = std::shared_ptr<GatherV2PCost>; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // PARALLEL_AUTO_PARALLEL_OPERATOR_COSTMODEL_H_ | |||
| @@ -129,6 +129,7 @@ REGISTER(ExpandDimsInfo); | |||
| REGISTER(SqueezeInfo); | |||
| REGISTER(SigmoidCrossEntropyWithLogitsInfo); | |||
| REGISTER(SquareInfo); | |||
| REGISTER(GatherV2PInfo); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -41,6 +41,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name | |||
| AnfNodePtr CreatTypeInt(int32_t value); | |||
| AnfNodePtr CreatInt32Imm(int32_t value); | |||
| AnfNodePtr CreateInt32Tensor(int32_t value); | |||
| AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr); | |||
| std::string HashInstanceName(const std::string &name); | |||
| class GenerateGraph { | |||
| @@ -0,0 +1,339 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "parallel/ops_info/gather_v2_p_info.h" | |||
| #include <vector> | |||
| #include <numeric> | |||
| #include <functional> | |||
| #include <utility> | |||
| #include "parallel/device_matrix.h" | |||
| #include "parallel/graph_util/generate_graph.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| Status GatherV2PInfo::GetAttrs() { | |||
| // get axis, the third input is the axis, is a ValueNode | |||
| if (input_value_.at(2) == nullptr) { | |||
| MS_LOG(ERROR) << name_ << ": the third input value is nullptr, is not a ValueNode!"; | |||
| return FAILED; | |||
| } | |||
| auto axis = GetValue<int>(input_value_.at(2)); | |||
| // if axis is negative then convert it to positive | |||
| auto params_shape = inputs_shape_.at(0); | |||
| if (params_shape.size() == 0) { | |||
| MS_LOG(ERROR) << name_ << ": params can not be a scalar!"; | |||
| return FAILED; | |||
| } | |||
| if (axis < 0) { | |||
| axis += SizeToInt(inputs_shape_[0].size()); | |||
| } | |||
| axis_ = axis; | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { | |||
| if (is_auto_parallel_) { | |||
| MS_LOG(DEBUG) << name_ << ": Invalid strategy."; | |||
| } else { | |||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||
| } | |||
| return FAILED; | |||
| } | |||
| // only support 1-dim and 2-dim param | |||
| if (inputs_shape_.at(0).size() != 1 && inputs_shape_.at(0).size() != 2) { | |||
| MS_LOG(ERROR) << name_ << ": Don't support param dim " << inputs_shape_.at(0).size(); | |||
| return FAILED; | |||
| } | |||
| // don't support scalar index | |||
| if (inputs_shape_.at(1).size() == 0) { | |||
| MS_LOG(ERROR) << name_ << ": Don't support scalar index."; | |||
| return FAILED; | |||
| } | |||
| // axis=0, index_shape(0)%param_strategy(0) must be 0 | |||
| Shape index_shape = inputs_shape_.at(1); | |||
| auto param_strategy = strategy->GetInputDim().at(0); | |||
| if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0)) { | |||
| MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by param_strategy(0)."; | |||
| return FAILED; | |||
| } | |||
| // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 | |||
| Shape param_shape = inputs_shape_.at(0); | |||
| if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(IntToSize(axis_))) != 0) { | |||
| MS_LOG(ERROR) << name_ << ": index_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; | |||
| return FAILED; | |||
| } | |||
| // Don't support repeated calc | |||
| auto params_strategy = strategy->GetInputDim().at(0); | |||
| CheckGlobalDeviceManager(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||
| auto product = std::accumulate(params_strategy.begin(), params_strategy.end(), 1, std::multiplies<int>()); | |||
| if (dev_num != IntToSize(product)) { | |||
| MS_LOG(ERROR) << name_ << ": Invalid strategy. Don't support repeated calc."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::InferDevMatrixShape() { | |||
| dev_matrix_shape_.clear(); | |||
| out_dev_matrix_shape_.clear(); | |||
| // infer input dev_matrix_shape | |||
| auto params_strategy = strategy_->GetInputDim().at(0); | |||
| dev_matrix_shape_ = params_strategy; | |||
| // infer out dev_matrix_shape | |||
| // axis!=0, split axis | |||
| if (axis_ != 0 && params_strategy.at(IntToSize(axis_)) != 1) { | |||
| out_dev_matrix_shape_.push_back(params_strategy.at(0) * params_strategy.at(IntToSize(axis_))); | |||
| for (size_t i = 1; i < params_strategy.size(); ++i) { | |||
| if (i == IntToSize(axis_)) { | |||
| out_dev_matrix_shape_.push_back(1); | |||
| } else { | |||
| out_dev_matrix_shape_.push_back(params_strategy.at(i)); | |||
| } | |||
| } | |||
| } else { | |||
| out_dev_matrix_shape_ = params_strategy; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::InferTensorMap() { | |||
| // infer input tensor map | |||
| size_t param_size = inputs_shape_.at(0).size(); | |||
| size_t index_size = inputs_shape_.at(1).size(); | |||
| std::vector<int32_t> tensor_map_index(index_size, -1); | |||
| std::vector<int32_t> tensor_map_params; | |||
| for (size_t i = 0; i < param_size; ++i) { | |||
| tensor_map_params.push_back(SizeToInt(param_size - i - 1)); | |||
| } | |||
| // infer output tensor map | |||
| std::vector<int32_t> tensor_map_out; | |||
| if (axis_ == 0) { | |||
| tensor_map_out.push_back(SizeToInt(param_size - 1)); | |||
| tensor_map_out.insert(tensor_map_out.end(), index_size - 1, -1); | |||
| for (size_t i = 1; i < param_size; ++i) { | |||
| tensor_map_out.push_back(SizeToInt(param_size - i - 1)); | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < param_size; ++i) { | |||
| if (i == IntToSize(axis_)) { | |||
| tensor_map_out.insert(tensor_map_out.end(), index_size, -1); | |||
| } else { | |||
| tensor_map_out.push_back(SizeToInt(param_size - i - 1)); | |||
| } | |||
| } | |||
| } | |||
| inputs_tensor_map_.emplace_back(std::move(tensor_map_params)); | |||
| inputs_tensor_map_.emplace_back(std::move(tensor_map_index)); | |||
| outputs_tensor_map_.emplace_back(std::move(tensor_map_out)); | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::InferTensorInfo() { | |||
| // infer tensor shape | |||
| Shape input_shape = inputs_shape_.at(0); | |||
| Shape input_index_shape = inputs_shape_.at(1); | |||
| Shape output_shape = outputs_shape_.at(0); | |||
| // infer tensor layout | |||
| TensorLayout input_tensor_layout, input_index_layout, output_tensor_layout; | |||
| if ((input_tensor_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(0), input_shape) != SUCCESS) || | |||
| (input_index_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_.at(1), input_index_shape) != SUCCESS) || | |||
| (output_tensor_layout.InitFromVector(out_dev_matrix_shape_, outputs_tensor_map_.at(0), output_shape) != | |||
| SUCCESS)) { | |||
| return FAILED; | |||
| } | |||
| // infer tensor info | |||
| TensorInfo input_tensor_info(input_tensor_layout); | |||
| TensorInfo input_index_info(input_index_layout); | |||
| TensorInfo output_tensor_info(output_tensor_layout); | |||
| inputs_tensor_info_.push_back(input_tensor_info); | |||
| inputs_tensor_info_.push_back(input_index_info); | |||
| outputs_tensor_info_.push_back(output_tensor_info); | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::InferBias() { | |||
| CheckGlobalDeviceManager(); | |||
| int32_t rank = g_device_manager->global_rank(); | |||
| auto input_shape = inputs_shape_.at(0); | |||
| auto params_strategy = strategy_->GetInputDim().at(0); | |||
| // params_size=1, axis=0 | |||
| if ((input_shape.size() == 1) && (axis_ == 0)) { | |||
| slice_size_ = input_shape.at(0) / params_strategy.at(0); | |||
| bias_ = rank * slice_size_; | |||
| return SUCCESS; | |||
| } | |||
| // params_size=2, axis=0 | |||
| if ((input_shape.size() == 2) && (axis_ == 0)) { | |||
| slice_size_ = input_shape.at(0) / params_strategy.at(0); | |||
| bias_ = rank / params_strategy.at(1) * slice_size_; | |||
| return SUCCESS; | |||
| } | |||
| // params_size=2, axis=1 | |||
| if ((input_shape.size() == 2) && (axis_ == 1)) { | |||
| slice_size_ = input_shape.at(1) / params_strategy.at(1); | |||
| bias_ = rank % params_strategy.at(1) * slice_size_; | |||
| return SUCCESS; | |||
| } | |||
| MS_LOG(ERROR) << name_ << ": Don't support params_size:" << input_shape.size() << " axis:" << axis_; | |||
| return FAILED; | |||
| } | |||
| Status GatherV2PInfo::InferGroup() { | |||
| std::vector<Group> group_list; | |||
| if (CreateGroupByDim(IntToSize(axis_), &group_list) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Create group failed."; | |||
| return FAILED; | |||
| } | |||
| group_ = group_list.at(0); | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| GenerateGraph gen_g = GenerateGraph(); | |||
| if (gen_g.Init(cnode) != SUCCESS) { | |||
| MS_LOG(ERROR) << "GenerateGraph Init failed"; | |||
| return FAILED; | |||
| } | |||
| if (InferBias() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Infer Bias failed."; | |||
| return FAILED; | |||
| } | |||
| auto sub = gen_g.PushBack({gen_g.NewOpInst(SUB), gen_g.virtual_input_node(), CreateInt32Tensor(bias_)}); | |||
| auto relu = gen_g.PushBack({gen_g.NewOpInst(RELU), sub}); | |||
| auto minimum = gen_g.PushBack({gen_g.NewOpInst(MINIMUM), relu, CreateInt32Tensor(slice_size_ - 1)}); | |||
| auto equal = gen_g.PushBack({gen_g.NewOpInst(EQUAL), gen_g.virtual_input_node(), minimum}); | |||
| auto gather_v2 = | |||
| gen_g.PushBack({gen_g.NewOpInst(GATHERV2), gen_g.virtual_input_node(), minimum, CreatInt32Imm(axis_)}); | |||
| auto dtype = gen_g.PushBack({gen_g.NewOpInst(DTYPE), gather_v2}); | |||
| auto cast = gen_g.PushBack({gen_g.NewOpInst(CAST), equal, dtype}); | |||
| auto expand_dims = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), cast, CreatInt32Imm(axis_ - 1)}); | |||
| auto mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, expand_dims}); | |||
| // don't need expandim,if param_size = 1, | |||
| if (inputs_shape_.at(0).size() == 1) { | |||
| mul = gen_g.PushBack({gen_g.NewOpInst(MUL), gather_v2, cast}); | |||
| } | |||
| if (InferGroup() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Infer Group failed."; | |||
| return FAILED; | |||
| } | |||
| Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | |||
| Attr attr_group = std::make_pair(GROUP, MakeValue(group_.name())); | |||
| OperatorAttrs attrs = {attr_op, attr_group}; | |||
| auto reduce_scatter = gen_g.PushBack({gen_g.NewOpInst(REDUCE_SCATTER, attrs), mul}); | |||
| std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(sub, 2), std::make_pair(gather_v2, 1), | |||
| std::make_pair(equal, 2)}; | |||
| replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>( | |||
| std::make_pair(input_nodes, reduce_scatter)); | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::Init(const StrategyPtr &strategy) { | |||
| auto param_strategy = strategy->GetInputDim().at(0); | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode_) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| if (is_auto_parallel_) { | |||
| MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; | |||
| } else { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| } | |||
| return FAILED; | |||
| } | |||
| auto param_strategy = strategy_->GetInputDim().at(0); | |||
| // cost model set axis and strategy | |||
| auto gatherv2_2cost = std::dynamic_pointer_cast<GatherV2PCost>(operator_cost()); | |||
| gatherv2_2cost->set_axis(axis_); | |||
| gatherv2_2cost->set_strategy(param_strategy); | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||
| if (is_auto_parallel_) { | |||
| MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; | |||
| } else { | |||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||
| } | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { | |||
| is_auto_parallel_ = true; | |||
| Shape input0_split(inputs_shape_[0].size(), 1); | |||
| Shapes splittable_inputs = {input0_split}; | |||
| std::vector<StrategyPtr> sp_vector; | |||
| if (GenerateStrategiesForIndependentInputs(stage_id, {inputs_shape_.at(0)}, splittable_inputs, &sp_vector) != | |||
| SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Generate strategies for independent inputs() failed."; | |||
| return FAILED; | |||
| } | |||
| size_t success = 0; | |||
| for (auto &sp : sp_vector) { | |||
| if (SetCostUnderStrategy(sp) == SUCCESS) { | |||
| success++; | |||
| MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; | |||
| PrintStrategy(sp); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<std::vector<std::vector<int32_t>>> GatherV2PInfo::GenerateBatchStrategies() { | |||
| CheckGlobalDeviceManager(); | |||
| size_t dev_num = g_device_manager->GetDeviceListByStageId(0).size(); | |||
| Dimensions strategy; | |||
| strategy.push_back(SizeToInt(dev_num)); | |||
| for (size_t i = 1; i < inputs_shape_[0].size(); i++) { | |||
| strategy.push_back(1); | |||
| } | |||
| std::vector<Dimensions> strategy_v = {strategy}; | |||
| return std::make_shared<std::vector<std::vector<int32_t>>>(strategy_v); | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | |||
| #define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "ir/value.h" | |||
| #include "parallel/auto_parallel/operator_costmodel.h" | |||
| #include "parallel/ops_info/operator_info.h" | |||
| #include "parallel/strategy.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| class GatherV2PInfo : public OperatorInfo { | |||
| public: | |||
| GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()) {} | |||
| ~GatherV2PInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| Status GenerateStrategies(int32_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; | |||
| protected: | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| Status InferMirrorOps() override { return SUCCESS; } | |||
| Status InferForwardCommunication() override { return SUCCESS; } | |||
| Status InferTensorInfo() override; | |||
| Status InferDevMatrixShape() override; | |||
| Status InferTensorMap() override; | |||
| Status GetAttrs() override; | |||
| private: | |||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | |||
| Status InferBias(); | |||
| Status InferGroup(); | |||
| int32_t axis_; | |||
| int32_t bias_; | |||
| int32_t slice_size_; | |||
| Shape out_dev_matrix_shape_; | |||
| Group group_; | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_GATHER_V2_P_INFO_H_ | |||
| @@ -215,9 +215,9 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
| OperatorAttrs attrs_onehot = {attr_onehot_axis}; | |||
| auto onehot = gen_g.PushBack({gen_g.NewOpInst(ONEHOT, attrs_onehot), sub2, CreatInt32Imm(classes_each_device_), | |||
| cnode->input(3), cnode->input(4)}); | |||
| std::vector<AnfNodePtr> input_nodes = {floor_div, sub1}; | |||
| replace_graph_ = | |||
| std::make_shared<std::pair<std::vector<AnfNodePtr>, AnfNodePtr>>(std::make_pair(input_nodes, onehot)); | |||
| std::vector<std::pair<AnfNodePtr, int>> input_nodes = {std::make_pair(floor_div, 1), std::make_pair(sub1, 1)}; | |||
| replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>( | |||
| std::make_pair(input_nodes, onehot)); | |||
| return SUCCESS; | |||
| } | |||
| @@ -48,7 +48,7 @@ using TensorLayouts = std::vector<TensorLayout>; | |||
| using different_type = std::vector<int32_t>::difference_type; | |||
| using PrimitiveAttrs = std::unordered_map<std::string, ValuePtr>; | |||
| using Strategys = std::vector<Dimensions>; | |||
| using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<AnfNodePtr>, AnfNodePtr>>; | |||
| using ReplaceGraphPtr = std::shared_ptr<std::pair<std::vector<std::pair<AnfNodePtr, int>>, AnfNodePtr>>; | |||
| class Edge; | |||
| @@ -36,5 +36,6 @@ | |||
| #include "parallel/ops_info/reshape_info.h" | |||
| #include "parallel/ops_info/transpose_info.h" | |||
| #include "parallel/ops_info/virtual_dataset_info.h" | |||
| #include "parallel/ops_info/gather_v2_p_info.h" | |||
| #endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_HEAD_FILES_H_ | |||
| @@ -114,7 +114,7 @@ constexpr char BE_CLONED_INDEX[] = "be_cloned_index"; | |||
| constexpr char GROUP_RANKS[] = "group_ranks"; | |||
| constexpr char IS_IN_FORWARD[] = "is_in_forward"; | |||
| constexpr char DEFAULT_INPUT[] = "default_input"; | |||
| constexpr char DTYPE[] = "dtype"; | |||
| constexpr char DTYPE[] = "DType"; | |||
| constexpr char DEV_NUM[] = "dev_num"; | |||
| constexpr char MEAN_FLAG[] = "mean_flag"; | |||
| constexpr char TYPES[] = "types"; | |||
| @@ -124,6 +124,7 @@ constexpr char SHARED_NAME[] = "shared_name"; | |||
| constexpr char MIRROR_OP[] = "mirror_op"; | |||
| constexpr char FORWARD_OP[] = "forward_op"; | |||
| constexpr char REDISTRIBUTION_OP[] = "redistribution_op"; | |||
| constexpr char DARA_PARALLEL[] = "data_parallel"; | |||
| // Operator | |||
| constexpr char VIRTUAL_DIV[] = "_VirtualDiv"; | |||
| @@ -605,8 +605,7 @@ bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { | |||
| return (prim->name() == name); | |||
| } | |||
| void StepReplaceGraph(const std::shared_ptr<std::pair<std::vector<AnfNodePtr>, AnfNodePtr>> &replace_graph, | |||
| const CNodePtr &node) { | |||
| void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(replace_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(replace_graph->second); | |||
| @@ -616,20 +615,10 @@ void StepReplaceGraph(const std::shared_ptr<std::pair<std::vector<AnfNodePtr>, A | |||
| if (manager == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr"; | |||
| } | |||
| if (!IsSomePrimitive(node, ONEHOT)) { | |||
| MS_LOG(EXCEPTION) << "Failure:Only OneHot Primitive will enter StepReplaceGraph!"; | |||
| } | |||
| if (node->inputs().size() != 5) { | |||
| MS_LOG(EXCEPTION) << "Failure:There is 5 inputs for the CNode corresponding to OneHot Primitive!"; | |||
| } | |||
| auto pre_node = node->input(1); | |||
| if (replace_graph->first.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "Failure:replace_graph->first.size() must be 2 for OneHot Primitive!"; | |||
| } | |||
| for (auto &replace_input : replace_graph->first) { | |||
| MS_EXCEPTION_IF_NULL(replace_input); | |||
| manager->SetEdge(replace_input, 1, pre_node); | |||
| CNodePtr replace_input_cnode = replace_input->cast<CNodePtr>(); | |||
| auto pre_node = node->input(IntToSize(replace_input.second)); | |||
| manager->SetEdge(replace_input.first, 1, pre_node); | |||
| auto replace_input_cnode = replace_input.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(replace_input_cnode); | |||
| (void)replace_input_cnode->set_operator_info(node->operator_info()); | |||
| replace_input_cnode->set_in_forward_flag(true); // mark this new cnode is forward node | |||
| @@ -943,6 +932,20 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA | |||
| MS_LOG(EXCEPTION) << "Length of name is zero!"; | |||
| } | |||
| std::string distribute_opname = GetDisOpName(name); | |||
| if (name == GATHERV2) { | |||
| distribute_opname = name + "PInfo"; | |||
| auto data_parallel_iter = attrs.find(DATA_PARALLEL); | |||
| if (data_parallel_iter != attrs.end()) { | |||
| MS_EXCEPTION_IF_NULL(data_parallel_iter->second); | |||
| if (!data_parallel_iter->second->isa<BoolImm>()) { | |||
| MS_LOG(EXCEPTION) << ": data_parallel flag's type is not a bool."; | |||
| } | |||
| bool data_parallel = data_parallel_iter->second->cast<BoolImmPtr>()->value(); | |||
| if (data_parallel) { | |||
| distribute_opname = name + "Info"; | |||
| } | |||
| } | |||
| } | |||
| OperatorInfoPtr operator_ = | |||
| (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS); | |||
| if (operator_ == nullptr) { | |||
| @@ -0,0 +1,173 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore.common.api import _executor | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.loss = VirtualLoss() | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| predict = self.network(x, y) | |||
| return self.loss(predict) | |||
| class GradWrap(nn.Cell): | |||
| def __init__(self, network): | |||
| super(GradWrap, self).__init__() | |||
| self.network = network | |||
| def construct(self, x, y): | |||
| return C.grad_all(self.network)(x, y) | |||
| class Net(nn.Cell): | |||
| def __init__(self, axis=0, strategy1=None, strategy2=None, shape=[64, 64]): | |||
| super().__init__() | |||
| self.gatherv2 = P.GatherV2().set_strategy(strategy1) | |||
| self.mul = P.Mul().set_strategy(strategy2) | |||
| self.index = Tensor(np.ones(shape), dtype=ms.int32) | |||
| self.axis = axis | |||
| def construct(self, x, y): | |||
| out = self.gatherv2(x, self.index, self.axis) | |||
| out = self.mul(out, y) | |||
| return out | |||
def test_gatherv2_semi_auto0():
    """Semi-auto parallel: gather on axis 0, param split only on dim 1."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    gather_strategy = ((1, 8),)
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(0, gather_strategy, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto1():
    """Semi-auto parallel: gather on axis 0 with the gather axis itself split."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    gather_strategy = ((8, 1),)
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(0, gather_strategy, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto2():
    """Semi-auto parallel: gather on axis 0, param split on both dims (2x4)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    gather_strategy = ((2, 4),)
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(0, gather_strategy, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto3():
    """Semi-auto parallel: gather on axis 1, non-gather dim split 8 ways."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    gather_strategy = ((1, 8),)
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(1, gather_strategy, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto4():
    """Semi-auto parallel: gather on axis 1, param dim 0 split 8 ways."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    gather_strategy = ((8, 1),)
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(1, gather_strategy, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto5():
    """Semi-auto parallel: gather on axis 1, param split on both dims (2x4)."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    gather_strategy = ((2, 4),)
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(1, gather_strategy, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto6():
    """Semi-auto parallel: gather on axis 0 with no explicit gather strategy."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(0, None, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto7():
    """Semi-auto parallel: gather on axis 1 with no explicit gather strategy."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    mul_strategy = ((4, 2, 1), (4, 2, 1))
    wrapped = GradWrap(NetWithLoss(Net(1, None, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_semi_auto8():
    """Semi-auto parallel: 1-D parameter gathered on axis 0, split 8 ways."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="semi_auto_parallel")
    gather_strategy = ((8,),)
    mul_strategy = ((4, 2), (4, 2))
    wrapped = GradWrap(NetWithLoss(Net(0, gather_strategy, mul_strategy)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_auto0():
    """Full auto-parallel: gather on axis 0, strategies left to the searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="auto_parallel")
    wrapped = GradWrap(NetWithLoss(Net(0)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 32]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
def test_gatherv2_auto1():
    """Full auto-parallel: gather on axis 1, strategies left to the searcher."""
    context.set_auto_parallel_context(device_num=8, global_rank=0,
                                      parallel_mode="auto_parallel")
    wrapped = GradWrap(NetWithLoss(Net(1)))
    wrapped.set_auto_parallel()
    param = Tensor(np.ones([64, 32]), dtype=ms.float32)
    other = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
    _executor.compile(wrapped, param, other)
| @@ -74,7 +74,7 @@ class GatherV2(_Loss): | |||
| emb2_list = np.reshape(emb_list[1::2], (int(index_size/2), 16)) | |||
| self.emb1_param = Tensor(emb1_list, dtype=mstype.int32) | |||
| self.emb2_param = Tensor(emb2_list, dtype=mstype.int32) | |||
| self.gatherv2 = P.GatherV2().set_strategy(strategy) | |||
| self.gatherv2 = P.GatherV2().set_strategy(strategy).add_prim_attr("data_parallel", True) | |||
| def construct(self, nembeddings): | |||
| emb1 = self.gatherv2(nembeddings, self.emb1_param, 0) | |||