| @@ -753,6 +753,7 @@ class ReduceSumCost : public OperatorCost { | |||
| }; | |||
| using ReduceMethodCost = ReduceSumCost; /* generic reduce-method cost: plain alias of ReduceSumCost */ | |||
| using ReduceProdCost = ReduceSumCost; /* ReduceProd reuses the ReduceSumCost model unchanged */ | |||
| using SquareSumAllCost = ReduceSumCost; /* SquareSumAll reuses the ReduceSumCost model unchanged; SquareSumAllInfo constructs it via make_shared */ | |||
| class ReduceMeanCost : public ReduceSumCost { | |||
| public: | |||
| @@ -784,6 +785,7 @@ class ArgMaxWithValueCost : public ReduceSumCost { | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using ArgMinWithValueCost = ArgMaxWithValueCost; /* ArgMinWithValue shares the ArgMaxWithValue cost model */ | |||
| using ArgmaxCost = ArgMaxWithValueCost; /* cost model handed to ReduceMethod by the ArgmaxInfo constructor (ArgminInfo inherits it) */ | |||
| class GetNextCost : public OperatorCost { | |||
| public: | |||
| @@ -919,6 +921,7 @@ class UnsortedSegmentSumCost : public OperatorCost { | |||
| // Taking account of input | |||
| void CalculateInputsInMemory(const std::map<size_t, bool> &prev_output_in_mem) override; | |||
| }; | |||
| using UnsortedSegmentProdCost = UnsortedSegmentSumCost; /* UnsortedSegmentProd reuses the UnsortedSegmentSum cost model; used by the UnsortedSegmentProdInfo constructor */ | |||
| class UnsortedSegmentMinCost : public OperatorCost { | |||
| public: | |||
| @@ -217,6 +217,10 @@ REGISTER(CropAndResizeInfo); | |||
| REGISTER(ROIAlignInfo); | |||
| REGISTER(ReduceProdInfo); | |||
| REGISTER(ReduceAllInfo); | |||
| REGISTER(ArgmaxInfo); /* NOTE(review): REGISTER is defined outside this view; presumably it makes the operator-info class creatable by name - confirm */ | |||
| REGISTER(ArgminInfo); /* ArgminInfo derives from ArgmaxInfo and only switches reduce_method_ to REDUCE_OP_MIN */ | |||
| REGISTER(UnsortedSegmentProdInfo); | |||
| REGISTER(SquareSumAllInfo); | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -111,6 +111,7 @@ constexpr char FUNCTIONAL_OP_PATH[] = "mindspore.ops.functional"; | |||
| constexpr char GET_OP_FUNCTION_PATH[] = "mindspore.parallel._utils"; | |||
| constexpr char GET_OP_FUNCTION[] = "_get_python_op"; | |||
| constexpr char KEEP_DIMS[] = "keep_dims"; | |||
| constexpr char OUTPUT_TYPE[] = "output_type"; /* attr key looked up by ArgmaxInfo::GetAttrs, which only null-checks the value */ | |||
| constexpr char CROSS_BATCH[] = "cross_batch"; /* bool attr key read by ArgmaxInfo::GetAttrs and SquareSumAllInfo::GetAttrs */ | |||
| constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin"; | |||
| constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; | |||
| @@ -335,6 +336,8 @@ constexpr char REDUCE_ALL[] = "ReduceAll"; | |||
| constexpr char REDUCE_ANY[] = "ReduceAny"; | |||
| constexpr char ARGMAXWITHVALUE[] = "ArgMaxWithValue"; | |||
| constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue"; | |||
| constexpr char ARGMAX[] = "Argmax"; /* primitive name; listed in the splittable_op set of IsSplittableOperator */ | |||
| constexpr char ARGMIN[] = "Argmin"; /* primitive name; listed in the splittable_op set of IsSplittableOperator */ | |||
| constexpr char CONV2D[] = "Conv2D"; | |||
| constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput"; | |||
| constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose"; | |||
| @@ -431,6 +434,7 @@ constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD"; | |||
| constexpr char UNSORTED_SEGMENT_SUM[] = "UnsortedSegmentSum"; | |||
| constexpr char UNSORTED_SEGMENT_MIN[] = "UnsortedSegmentMin"; | |||
| constexpr char UNSORTED_SEGMENT_MAX[] = "UnsortedSegmentMax"; | |||
| constexpr char UNSORTED_SEGMENT_PROD[] = "UnsortedSegmentProd"; /* primitive name; listed in the splittable_op set of IsSplittableOperator */ | |||
| constexpr char DEPTHWISE_CONV2D_NATIVE[] = "DepthwiseConv2dNative"; | |||
| constexpr char DEPTHWISE_CONV2D[] = "DepthwiseConv2D"; | |||
| constexpr char DROPOUT[] = "Dropout"; | |||
| @@ -449,6 +453,7 @@ constexpr char RANDOM_CHOICE_WITH_MASK[] = "RandomChoiceWithMask"; | |||
| constexpr char CROP_AND_RESIZE[] = "CropAndResize"; | |||
| constexpr char MASKED_FILL[] = "MaskedFill"; | |||
| constexpr char ROI_ALIGN[] = "ROIAlign"; | |||
| constexpr char SQUARE_SUM_ALL[] = "SquareSumAll"; /* primitive name; also passed to NewOpInst in SquareSumAllInfo::ComputeReplaceGraph */ | |||
| // pipeline | |||
| constexpr size_t PIPELINE_FUSTION_OFFSET = 100; | |||
| @@ -482,6 +487,7 @@ constexpr char MAKE_RECORD[] = "make_record"; | |||
| constexpr char LIST_GETITEM[] = "list_getitem"; | |||
| constexpr char ARRAY_GETITEM[] = "array_getitem"; | |||
| constexpr char TUPLE_SETITEM[] = "tuple_setitem"; | |||
| constexpr char TUPLE_GETITEM[] = "tuple_getitem"; /* op name used by SquareSumAllInfo::ComputeReplaceGraph to pick output 0 and output 1 */ | |||
| constexpr char LIST_SETITEM[] = "list_setitem"; | |||
| constexpr char ARRAY_SETITEM[] = "array_setitem"; | |||
| constexpr char DICT_GETITEM[] = "dict_getitem"; | |||
| @@ -26,6 +26,7 @@ | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -578,15 +579,321 @@ Status ArgMaxWithValueInfo::InferAsLossDivisor() { | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> ArgMaxWithValueInfo::GenerateOpStrategies(int64_t stage_id) { /* NOTE(review): this header through the MS_LOG(EXCEPTION) row, plus the later `return sp_vector;` row, look like removed-side diff rows interleaved with the added ArgmaxInfo::reduce_dim - confirm against the raw patch */ | |||
| Shape input0_split(inputs_shape_[0].size(), 1); | |||
| Shapes splittable_inputs = {input0_split}; | |||
| std::vector<StrategyPtr> sp_vector; | |||
| if (GenerateStrategiesForIndependentInputs(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << ": GenerateStrategiesForIndependentInputs failed."; | |||
| std::vector<int64_t> ArgmaxInfo::reduce_dim() { /* Returns the axes named by the AXIS attr as non-negative indices; raises via MS_LOG(EXCEPTION) when the attr is missing or has an unsupported type */ | |||
| // read the axis attribute | |||
| std::vector<int64_t> dim_list; | |||
| auto iter = attrs_.find(AXIS); | |||
| if (iter == attrs_.end()) { | |||
| MS_LOG(EXCEPTION) << name_ << ": Don't have attribution axis."; | |||
| } | |||
| return sp_vector; | |||
| MS_ASSERT(inputs_shape_.size() == 1); | |||
| auto input_dim = inputs_shape_.at(0).size(); /* rank of the first input; used to wrap negative axes */ | |||
| MS_EXCEPTION_IF_NULL(iter->second); | |||
| if (iter->second->isa<ValueTuple>()) { /* tuple axis */ | |||
| auto attr_axis = GetValue<std::vector<int64_t>>(iter->second); | |||
| if (attr_axis.empty()) { /* an empty tuple selects every input dimension */ | |||
| for (size_t i = 0; i < input_dim; ++i) { | |||
| dim_list.push_back(SizeToLong(i)); | |||
| } | |||
| } else { | |||
| for (auto &axis : attr_axis) { | |||
| axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis); /* negative axis is offset by the input rank */ | |||
| } | |||
| } | |||
| } else if (iter->second->isa<Int64Imm>()) { /* single integer axis */ | |||
| int64_t axis = GetValue<int64_t>(iter->second); | |||
| axis < 0 ? dim_list.push_back(axis + SizeToLong(input_dim)) : dim_list.push_back(axis); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Axis type is invalid."; | |||
| } | |||
| return dim_list; | |||
| } | |||
| // Loads the operator attributes for Argmax/Argmin. | |||
| // keepdims_ is always false. output_type, when present, is only null-checked. cross_batch, when | |||
| // present, must be a bool and is forwarded to the cost model. | |||
| // Returns FAILED if cross_batch is not a bool or the operator cost is not a ReduceMethodCost. | |||
| Status ArgmaxInfo::GetAttrs() { | |||
|   keepdims_ = false; | |||
|   // output_type: presence is optional and its value is not consumed here | |||
|   auto output_type_iter = attrs_.find(OUTPUT_TYPE); | |||
|   if (output_type_iter != attrs_.end()) { | |||
|     MS_EXCEPTION_IF_NULL(output_type_iter->second); | |||
|   } | |||
|   // cross_batch: optional bool | |||
|   auto cross_batch_iter = attrs_.find(CROSS_BATCH); | |||
|   const bool has_cross_batch = (cross_batch_iter != attrs_.end()); | |||
|   if (has_cross_batch) { | |||
|     MS_EXCEPTION_IF_NULL(cross_batch_iter->second); | |||
|     const bool is_bool = cross_batch_iter->second->isa<BoolImm>(); | |||
|     if (!is_bool) { | |||
|       MS_LOG(ERROR) << name_ << ": cross_batch is not a bool."; | |||
|       return FAILED; | |||
|     } | |||
|     cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value(); | |||
|   } | |||
|   // hand the flag to the cost model | |||
|   auto cost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost()); | |||
|   if (!cost) { | |||
|     MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!"; | |||
|     return FAILED; | |||
|   } | |||
|   cost->set_cross_batch(cross_batch_); | |||
|   return SUCCESS; | |||
| } | |||
| // Validates the strategy through ReduceMethod::CheckStrategy, then warns (without failing) | |||
| // when the strategy value on the reduced axis is not 1. | |||
| Status ArgmaxInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
|   const Status parent_status = ReduceMethod::CheckStrategy(strategy); | |||
|   if (parent_status != SUCCESS) { | |||
|     MS_LOG(ERROR) << name_ << ": CheckStrategy for parent class ReduceMethod failed"; | |||
|     return FAILED; | |||
|   } | |||
|   std::vector<int64_t> dim_list = reduce_dim(); | |||
|   MS_ASSERT(dim_list.size() == 1); | |||
|   Strategys stra = strategy->GetInputDim(); | |||
|   MS_ASSERT(stra.size() == 1); | |||
|   Shape input_strategy = stra.at(0); | |||
|   MS_ASSERT(dim_list.at(0) < input_strategy.size()); | |||
|   // split factor applied to the reduced axis | |||
|   const auto axis_split = input_strategy.at(LongToSize(dim_list.at(0))); | |||
|   if (axis_split == 1) { | |||
|     return SUCCESS; | |||
|   } | |||
|   MS_LOG(WARNING) << name_ | |||
|                   << " CheckStrategy for Argmax/Argmin, the strategy corresponding to axis is not one, real strategy " | |||
|                      "is " | |||
|                   << axis_split | |||
|                   << ", the output index may be not compatible with the stand alone Primitive"; | |||
|   return SUCCESS; | |||
| } | |||
| // Delegates mirror-op inference to OperatorInfo and logs when that base implementation fails. | |||
| Status ArgmaxInfo::InferMirrorOps() { | |||
|   const Status base_status = OperatorInfo::InferMirrorOps(); | |||
|   if (base_status == SUCCESS) { | |||
|     return SUCCESS; | |||
|   } | |||
|   MS_LOG(ERROR) << name_ << ": InferMirrorOps for parent class OperatorInfo failed"; | |||
|   return FAILED; | |||
| } | |||
| // SquareSumAll reduces over every dimension, so the result is 0 .. rank-1 of the first input. | |||
| std::vector<int64_t> SquareSumAllInfo::reduce_dim() { | |||
|   const size_t rank = inputs_shape_.at(0).size(); | |||
|   std::vector<int64_t> every_axis; | |||
|   every_axis.reserve(rank); | |||
|   for (size_t axis = 0; axis != rank; ++axis) { | |||
|     every_axis.emplace_back(SizeToLong(axis)); | |||
|   } | |||
|   return every_axis; | |||
| } | |||
| // Loads the operator attributes for SquareSumAll. | |||
| // keepdims_ is always false. cross_batch, when present, must be a bool and is forwarded to the cost model. | |||
| // Returns FAILED if cross_batch is not a bool or the operator cost is not a ReduceMethodCost. | |||
| Status SquareSumAllInfo::GetAttrs() { | |||
|   keepdims_ = false; | |||
|   auto cross_batch_iter = attrs_.find(CROSS_BATCH); | |||
|   const bool has_cross_batch = (cross_batch_iter != attrs_.end()); | |||
|   if (has_cross_batch) { | |||
|     MS_EXCEPTION_IF_NULL(cross_batch_iter->second); | |||
|     const bool is_bool = cross_batch_iter->second->isa<BoolImm>(); | |||
|     if (!is_bool) { | |||
|       MS_LOG(ERROR) << name_ << ": cross_batch is not a bool."; | |||
|       return FAILED; | |||
|     } | |||
|     cross_batch_ = cross_batch_iter->second->cast<BoolImmPtr>()->value(); | |||
|   } | |||
|   // hand the flag to the cost model | |||
|   auto cost = std::dynamic_pointer_cast<ReduceMethodCost>(operator_cost()); | |||
|   if (!cost) { | |||
|     MS_LOG(ERROR) << "Cost cast to ReduceMethodCostPtr failed!"; | |||
|     return FAILED; | |||
|   } | |||
|   cost->set_cross_batch(cross_batch_); | |||
|   return SUCCESS; | |||
| } | |||
| // Validates a SquareSumAll strategy. | |||
| // After the generic CheckStrategyValue check, both inputs must have the same shape and the same | |||
| // per-dimension strategy: InferDevMatrixShape and InferTensorInfo derive everything from input 0 | |||
| // and reuse it for input 1. | |||
| // The previous condition joined the two comparisons with &&, so a strategy mismatch was never | |||
| // reported when the shapes were equal; the size checks were MS_ASSERT only, ahead of unchecked indexing. | |||
| // Returns FAILED (with an error log) on any violation, SUCCESS otherwise. | |||
| Status SquareSumAllInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
|   if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||
|     MS_LOG(ERROR) << name_ << " : Invalid strategy."; | |||
|     return FAILED; | |||
|   } | |||
|   Strategys stra = strategy->GetInputDim(); | |||
|   // both inputs and both strategies must exist before they are indexed | |||
|   if (stra.size() != 2 || inputs_shape_.size() != 2) { | |||
|     MS_LOG(ERROR) << name_ << " : expects 2 inputs and 2 input strategies, but got " << inputs_shape_.size() | |||
|                   << " inputs and " << stra.size() << " strategies."; | |||
|     return FAILED; | |||
|   } | |||
|   const Dimensions &sub_a_strategy = stra.at(0); | |||
|   const Dimensions &sub_b_strategy = stra.at(1); | |||
|   const Shape &input_a_shape = inputs_shape_.at(0); | |||
|   const Shape &input_b_shape = inputs_shape_.at(1); | |||
|   // ranks must agree so the element-wise comparison below stays in bounds | |||
|   if (input_a_shape.size() != input_b_shape.size() || sub_a_strategy.size() != sub_b_strategy.size() || | |||
|       sub_a_strategy.size() != input_a_shape.size()) { | |||
|     MS_LOG(ERROR) << name_ << " : the ranks of the inputs and of their strategies must all be equal."; | |||
|     return FAILED; | |||
|   } | |||
|   for (size_t i = 0; i < input_a_shape.size(); ++i) { | |||
|     // either mismatch alone is enough to invalidate the strategy | |||
|     if ((sub_a_strategy[i] != sub_b_strategy[i]) || (input_a_shape[i] != input_b_shape[i])) { | |||
|       MS_LOG(ERROR) << name_ << " : Invalid strategy."; | |||
|       return FAILED; | |||
|     } | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| // The device matrix is a copy of the strategy of the first input. | |||
| Status SquareSumAllInfo::InferDevMatrixShape() { | |||
|   Strategys all_inputs = strategy_->GetInputDim(); | |||
|   Dimensions first_input = all_inputs.at(0); | |||
|   Shape matrix(first_input.begin(), first_input.end()); | |||
|   dev_matrix_shape_ = matrix; | |||
|   return SUCCESS; | |||
| } | |||
| // Runs ReduceMethod::InferTensorMap, then appends a copy of the first input map and of the | |||
| // first output map so that both inputs and both outputs have an entry. | |||
| Status SquareSumAllInfo::InferTensorMap() { | |||
|   const Status base_status = ReduceMethod::InferTensorMap(); | |||
|   if (base_status != SUCCESS) { | |||
|     MS_LOG(ERROR) << name_ << ": InferTensorMap for parent class ReduceMethod failed"; | |||
|     return FAILED; | |||
|   } | |||
|   MS_ASSERT(outputs_tensor_map_.size() == 1); | |||
|   auto first_input_map = inputs_tensor_map_[0]; | |||
|   inputs_tensor_map_.emplace_back(first_input_map); | |||
|   auto first_output_map = outputs_tensor_map_[0]; | |||
|   outputs_tensor_map_.emplace_back(first_output_map); | |||
|   return SUCCESS; | |||
| } | |||
| // Builds the tensor info of input 0 and output 0 and registers each of them twice, once per | |||
| // input and once per output. Returns FAILED when slice-shape or layout inference fails. | |||
| Status SquareSumAllInfo::InferTensorInfo() { | |||
|   // full shapes | |||
|   Shape full_in = inputs_shape_.at(0); | |||
|   Shape full_out = outputs_shape_.at(0); | |||
|   // slice shapes: both outputs use the same inferred output strategy | |||
|   Shapes in_slices; | |||
|   Shapes out_slices; | |||
|   Strategys in_strategy = strategy_->GetInputDim(); | |||
|   Dimensions out_dims = InferOutputStrategy(); | |||
|   Strategys out_strategy = {out_dims, out_dims}; | |||
|   if (InferSliceShape(in_strategy, out_strategy, &in_slices, &out_slices) != SUCCESS) { | |||
|     return FAILED; | |||
|   } | |||
|   Shape in_slice = in_slices.at(0); | |||
|   Shape out_slice = out_slices.at(0); | |||
|   // layouts: the output layout is only initialised when the input layout succeeded | |||
|   TensorLayout in_layout; | |||
|   TensorLayout out_layout; | |||
|   if (in_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], full_in) != SUCCESS) { | |||
|     return FAILED; | |||
|   } | |||
|   if (out_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], full_out) != SUCCESS) { | |||
|     return FAILED; | |||
|   } | |||
|   std::vector<int64_t> reduced_axes = reduce_dim(); | |||
|   TensorInfo in_info(in_layout, full_in, in_slice); | |||
|   TensorInfo out_info(out_layout, full_out, out_slice); | |||
|   in_info.set_reduce_dim(reduced_axes); | |||
|   // one entry per input / per output | |||
|   for (int copy = 0; copy < 2; ++copy) { | |||
|     inputs_tensor_info_.push_back(in_info); | |||
|   } | |||
|   for (int copy = 0; copy < 2; ++copy) { | |||
|     outputs_tensor_info_.push_back(out_info); | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| // Builds the tensor map handed to CreateGroupByTensorMap and stores the resulting group in group_. | |||
| // A dimension is left out of the map when it is both reduced and split (strategy value != 1); | |||
| // the repeated-calculation dimension of the device matrix is accounted for on either side. | |||
| Status SquareSumAllInfo::InferGroup() { | |||
|   Dimensions input_split = strategy_->GetInputDim().at(0); | |||
|   forward_op_.clear(); | |||
|   std::vector<int64_t> reduced_axes = reduce_dim(); | |||
|   const size_t rank = input_split.size(); | |||
|   Shape group_map; | |||
|   // repeated-calculation number placed in the first device-matrix dimension: add that dimension first | |||
|   const bool repeat_on_left = (dev_matrix_shape_.size() > rank) && !repeated_num_in_dev_matrix_right_; | |||
|   if (repeat_on_left) { | |||
|     group_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); | |||
|   } | |||
|   for (size_t index = 0; index < rank; ++index) { | |||
|     const bool is_reduced = | |||
|       std::find(reduced_axes.begin(), reduced_axes.end(), SizeToLong(index)) != reduced_axes.end(); | |||
|     const bool is_split = (input_split[index] != 1); | |||
|     if (!(is_reduced && is_split)) { | |||
|       group_map.push_back(SizeToLong(rank) - SizeToLong(index) - 1); | |||
|     } | |||
|   } | |||
|   // repeated-calculation number placed in the last device-matrix dimension: shift every mapped | |||
|   // entry by one and append 0 for the repeated dimension | |||
|   if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) { | |||
|     for (auto &entry : group_map) { | |||
|       if (entry != MAP_NONE) { | |||
|         entry += 1; | |||
|       } | |||
|     } | |||
|     group_map.push_back(0); | |||
|   } | |||
|   const Status group_status = CreateGroupByTensorMap(group_map, &group_); | |||
|   if (group_status != SUCCESS) { | |||
|     ReportError(name_ + ": Create group failed."); | |||
|     return FAILED; | |||
|   } | |||
|   return SUCCESS; | |||
| } | |||
| // Builds the replacement sub-graph for a sharded SquareSumAll and stores it in replace_graph_: | |||
| //   SquareSumAll(x, y) -> tuple_getitem 0 / 1 -> AllReduce(sum) on each -> make_list | |||
| // Returns FAILED when the graph generator cannot be initialised or the group cannot be inferred. | |||
| // Previously both failures were only logged and execution continued, so replace_graph() could | |||
| // never observe the error. | |||
| Status SquareSumAllInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||
|   GenerateGraph gen_g = GenerateGraph(attrs_); | |||
|   if (gen_g.Init(cnode) != SUCCESS) { | |||
|     MS_LOG(ERROR) << name_ << ": GenerateGraph Init failed"; | |||
|     return FAILED; | |||
|   } | |||
|   if (InferGroup() != SUCCESS) { | |||
|     MS_LOG(ERROR) << name_ << ": Infer Group failed"; | |||
|     return FAILED; | |||
|   } | |||
|   MS_LOG(INFO) << name_ << ": The rank is " << g_device_manager->rank_index_in_stage(); | |||
|   // local SquareSumAll on the two virtual inputs, then split its two outputs | |||
|   auto square_sum_all = | |||
|     gen_g.PushBack({gen_g.NewOpInst(SQUARE_SUM_ALL), gen_g.virtual_input_node(), gen_g.virtual_input_node()}); | |||
|   auto get_item0 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), square_sum_all, CreatInt64Imm(0)}); | |||
|   auto get_item1 = gen_g.PushBack({gen_g.NewOpInst(TUPLE_GETITEM), square_sum_all, CreatInt64Imm(1)}); | |||
|   // NOTE(review): group_[0] assumes InferGroup left at least one group in group_; whether | |||
|   // CreateGroupByTensorMap can succeed with an empty group is not visible here - confirm. | |||
|   Attr attr_op = std::make_pair(OP, MakeValue(REDUCE_OP_SUM)); | |||
|   Attr attr_group = std::make_pair(GROUP, MakeValue(group_[0].name())); | |||
|   OperatorAttrs attrs = {attr_op, attr_group}; | |||
|   // sum each partial result across the group | |||
|   auto allreduce_op0 = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), get_item0}); | |||
|   auto allreduce_op1 = gen_g.PushBack({gen_g.NewOpInst(ALL_REDUCE, attrs), get_item1}); | |||
|   auto make_list = gen_g.PushBack({gen_g.NewOpInst(MAKE_LIST), allreduce_op0, allreduce_op1}); | |||
|   // the original inputs are wired to input index 1 and 2 of the new SquareSumAll node | |||
|   std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(square_sum_all, 1), | |||
|                                                              std::make_pair(square_sum_all, 2)}; | |||
|   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>( | |||
|     std::make_pair(input_nodes, make_list)); | |||
|   return SUCCESS; | |||
| } | |||
| // Computes the replacement graph for cnode and returns it; raises through MS_LOG(EXCEPTION) | |||
| // when ComputeReplaceGraph does not report SUCCESS. | |||
| ReplaceGraphPtr SquareSumAllInfo::replace_graph(const CNodePtr &cnode) { | |||
|   const Status build_status = ComputeReplaceGraph(cnode); | |||
|   if (build_status != SUCCESS) { | |||
|     MS_LOG(EXCEPTION) << name_ << ": ComputeReplaceGraph failed."; | |||
|   } | |||
|   return replace_graph_; | |||
| } | |||
| // Delegates mirror-op inference to OperatorInfo and logs when that base implementation fails. | |||
| Status SquareSumAllInfo::InferMirrorOps() { | |||
|   const Status base_status = OperatorInfo::InferMirrorOps(); | |||
|   if (base_status == SUCCESS) { | |||
|     return SUCCESS; | |||
|   } | |||
|   MS_LOG(ERROR) << name_ << ": InferMirrorOps for parent class OperatorInfo failed"; | |||
|   return FAILED; | |||
| } | |||
| // Sets as_loss_divisor_ from the tensor map of output 0 (the operator has two outputs). | |||
| // An empty map (scalar output) uses stage_device_size_; otherwise the value comes from | |||
| // ComputeRepeatDeviceNumByTensorMap. Returns FAILED when there is no output tensor map at all. | |||
| Status SquareSumAllInfo::InferAsLossDivisor() { | |||
|   if (outputs_tensor_map_.empty()) { | |||
|     MS_LOG(ERROR) << name_ << ": The outputs tensor map is empty."; | |||
|     return FAILED; | |||
|   } | |||
|   MS_LOG(INFO) << name_ << " has two outputs, use output[0] to infer"; | |||
|   const auto &first_output_map = outputs_tensor_map_[0]; | |||
|   if (first_output_map.empty()) { | |||
|     as_loss_divisor_ = stage_device_size_; | |||
|     MS_LOG(INFO) << name_ << ": The output is a scalar, use the dev size" << as_loss_divisor_ << " as loss divisor."; | |||
|     return SUCCESS; | |||
|   } | |||
|   as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, first_output_map); | |||
|   const std::string matrix_text = ShapeToString(dev_matrix_shape_); | |||
|   const std::string map_text = ShapeToString(first_output_map); | |||
|   MS_LOG(INFO) << name_ << ": the dev matrix shape, the output tensor map, and loss divisor is " << matrix_text | |||
|                << ", " << map_text << ", " << as_loss_divisor_; | |||
|   return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -76,8 +76,6 @@ class ArgMaxWithValueInfo : public ReduceMethod { | |||
| ~ArgMaxWithValueInfo() override = default; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| protected: | |||
| std::vector<int64_t> reduce_dim() override; | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| @@ -167,6 +165,63 @@ class ReduceAllInfo : public ReduceAnyInfo { | |||
| ~ReduceAllInfo() override = default; | |||
| }; | |||
| class ArgmaxInfo : public ReduceMethod { /* Parallel operator info for Argmax: a ReduceMethod built with an ArgmaxCost cost model */ | |||
| public: | |||
| ArgmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArgmaxCost>()) { | |||
| reduce_method_ = REDUCE_OP_MAX; /* ArgminInfo overwrites this with REDUCE_OP_MIN in its own constructor */ | |||
| } | |||
| ~ArgmaxInfo() override = default; | |||
| protected: | |||
| std::vector<int64_t> reduce_dim() override; /* axes from the AXIS attr, negative values offset by the input rank */ | |||
| Status GetAttrs() override; /* forces keepdims_ to false and reads the optional cross_batch bool */ | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; /* ReduceMethod check, then a warning (not a failure) if the reduced axis is split */ | |||
| Status InferMirrorOps() override; /* forwards to OperatorInfo::InferMirrorOps */ | |||
| }; | |||
| class ArgminInfo : public ArgmaxInfo { /* Parallel operator info for Argmin: identical to ArgmaxInfo except for reduce_method_ */ | |||
| public: | |||
| ArgminInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ArgmaxInfo(name, inputs_shape, outputs_shape, attrs) { | |||
| reduce_method_ = REDUCE_OP_MIN; /* replaces the REDUCE_OP_MAX set by the ArgmaxInfo constructor */ | |||
| } | |||
| ~ArgminInfo() override = default; | |||
| }; | |||
| class SquareSumAllInfo : public ReduceMethod { /* Parallel operator info for SquareSumAll (two inputs, two outputs), built on ReduceMethod with a SquareSumAllCost cost model */ | |||
| public: | |||
| SquareSumAllInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : ReduceMethod(name, inputs_shape, outputs_shape, attrs, std::make_shared<SquareSumAllCost>()) { | |||
| reduce_method_ = REDUCE_OP_SUM; | |||
| } | |||
| ~SquareSumAllInfo() override = default; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; /* runs ComputeReplaceGraph and returns replace_graph_; raises on failure */ | |||
| protected: | |||
| std::vector<int64_t> reduce_dim() override; /* every dimension of input 0 */ | |||
| Status GetAttrs() override; /* forces keepdims_ to false and reads the optional cross_batch bool */ | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| Status InferDevMatrixShape() override; /* device matrix = strategy of input 0 */ | |||
| Status InferTensorMap() override; /* duplicates the first input and output tensor maps */ | |||
| Status InferTensorInfo() override; /* registers the tensor info of input 0 / output 0 twice each */ | |||
| Status InferForwardCommunication() override { return SUCCESS; } /* no forward op is added here; the AllReduce nodes are created in ComputeReplaceGraph */ | |||
| Status InferMirrorOps() override; | |||
| Status InferAsLossDivisor() override; /* derived from the tensor map of output 0 */ | |||
| private: | |||
| Status InferGroup(); /* fills group_ through CreateGroupByTensorMap */ | |||
| Status ComputeReplaceGraph(const CNodePtr &cnode); /* builds SquareSumAll -> tuple_getitem -> AllReduce -> make_list */ | |||
| std::vector<Group> group_; /* group_[0] supplies the AllReduce group name in ComputeReplaceGraph */ | |||
| }; | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_REDUCE_SUM_INFO_H_ | |||
| @@ -205,7 +205,7 @@ Status UnsortedSegmentOpInfo::InferForwardCommunication() { | |||
| return SUCCESS; | |||
| } | |||
| Operator op = CreateAllReduceOp(REDUCE_OP_SUM, group_list[0].name()); | |||
| Operator op = CreateAllReduceOp(reduce_method_, group_list[0].name()); | |||
| forward_op_.push_back(op); | |||
| MS_LOG(INFO) << name_ << " : The group name of forward communication is " << group_list[0].name(); | |||
| return SUCCESS; | |||
| @@ -50,6 +50,7 @@ class UnsortedSegmentOpInfo : public OperatorInfo { | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| protected: | |||
| std::string reduce_method_; | |||
| Status CheckStrategy(const StrategyPtr &strategy) override; | |||
| Status InferForwardCommunication() override; | |||
| Status InferMirrorOps() override; | |||
| @@ -65,15 +66,29 @@ class UnsortedSegmentSumInfo : public UnsortedSegmentOpInfo { | |||
| public: | |||
| UnsortedSegmentSumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) {} | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentSumCost>()) { | |||
| reduce_method_ = REDUCE_OP_SUM; | |||
| } | |||
| ~UnsortedSegmentSumInfo() override = default; | |||
| }; | |||
| class UnsortedSegmentProdInfo : public UnsortedSegmentOpInfo { /* Parallel operator info for UnsortedSegmentProd, built on UnsortedSegmentOpInfo with an UnsortedSegmentProdCost cost model */ | |||
| public: | |||
| UnsortedSegmentProdInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentProdCost>()) { | |||
| reduce_method_ = REDUCE_OP_PROD; /* passed to CreateAllReduceOp by UnsortedSegmentOpInfo::InferForwardCommunication */ | |||
| } | |||
| ~UnsortedSegmentProdInfo() override = default; | |||
| }; | |||
| class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo { | |||
| public: | |||
| UnsortedSegmentMinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) {} | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMinCost>()) { | |||
| reduce_method_ = REDUCE_OP_MIN; | |||
| } | |||
| ~UnsortedSegmentMinInfo() override = default; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| @@ -86,7 +101,9 @@ class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { | |||
| public: | |||
| UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | |||
| const PrimitiveAttrs &attrs) | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) {} | |||
| : UnsortedSegmentOpInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<UnsortedSegmentMaxCost>()) { | |||
| reduce_method_ = REDUCE_OP_MAX; | |||
| } | |||
| ~UnsortedSegmentMaxInfo() override = default; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| @@ -174,7 +174,8 @@ bool IsSplittableOperator(const std::string &op_name) { | |||
| UNSORTED_SEGMENT_MIN, REPEAT_ELEMENTS, TENSOR_DOT, RANGE, UNIFORM_CANDIDATE_SAMPLER, SLICE, SELECT, GATHERD, | |||
| UNSORTED_SEGMENT_MAX, GATHER_ND, TOPK, SCATTER_UPDATE, VIRTUAL_OUTPUT, CONV2D_BACK_PROP_INPUT, CONV2D_TRANSPOSE, | |||
| MATMUL_DDS, DSD_MATMUL, UNIFORMREAL, RESIZE_BILINEAR, RESIZE_NEAREST_NEIGHBOR, CUMSUM, FAST_GELU, IOU, | |||
| BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL}; | |||
| BOUNDING_BOX_ENCODE, RANDOM_CHOICE_WITH_MASK, CROP_AND_RESIZE, ROI_ALIGN, REDUCE_PROD, REDUCE_ANY, REDUCE_ALL, | |||
| ARGMAX, ARGMIN, UNSORTED_SEGMENT_PROD, SQUARE_SUM_ALL}; | |||
| // clang-format on | |||
| auto iter = splittable_op.find(op_name); | |||
| @@ -415,6 +415,29 @@ class ArgMinWithValueNet(nn.Cell): | |||
| out = self.mul2(out, b) | |||
| return out | |||
| class ArgMaxNet(nn.Cell): | |||
| def __init__(self, strategy1, strategy2): | |||
| super(ArgMaxNet, self).__init__() | |||
| self.mul1 = P.Mul().shard(strategy1) | |||
| self.arg_max = P.Argmax(axis=-1).shard(strategy2) | |||
| def construct(self, x, y): | |||
| out = self.mul1(x, y) | |||
| out = self.arg_max(out) | |||
| return out | |||
| class ArgMinNet(nn.Cell): | |||
| def __init__(self, strategy1, strategy2): | |||
| super(ArgMinNet, self).__init__() | |||
| self.mul1 = P.Mul().shard(strategy1) | |||
| self.arg_min = P.Argmin(axis=-1).shard(strategy2) | |||
| def construct(self, x, y): | |||
| out = self.mul1(x, y) | |||
| out = self.arg_min(out) | |||
| return out | |||
| def gen_inputs_and_compile_net(net): | |||
| x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32) | |||
| @@ -514,6 +537,90 @@ def test_arg_min_with_value_mul_auto(): | |||
| gen_inputs_and_compile_net(net) | |||
| def test_arg_max_semi_axis_parallel(): | |||
| """ | |||
| Feature: test Argmax semi parallel strategy | |||
| Description: partition the reduced axes | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = ((1, 4, 2), (1, 4, 2)) | |||
| strategy2 = ((4, 1, 2),) | |||
| net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| gen_inputs_and_compile_net_no_bias(net) | |||
| def test_arg_max_mul_semi(): | |||
| """ | |||
| Feature: test Argmax model parallel strategy | |||
| Description: partition the non-reduced axes | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = ((1, 4, 2), (1, 4, 2)) | |||
| strategy2 = ((4, 2, 1),) | |||
| net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| gen_inputs_and_compile_net_no_bias(net) | |||
| def test_arg_max_mul_auto(): | |||
| """ | |||
| Feature: test Argmax auto parallel strategy | |||
| Description: don't set the strategy | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = None | |||
| strategy2 = None | |||
| net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2))) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| gen_inputs_and_compile_net_no_bias(net) | |||
| def test_arg_min_semi_axis_parallel(): | |||
| """ | |||
| Feature: test Argmin semi parallel strategy | |||
| Description: partition the reduced axes | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = ((1, 4, 2), (1, 4, 2)) | |||
| strategy2 = ((4, 1, 2),) | |||
| net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| gen_inputs_and_compile_net_no_bias(net) | |||
| def test_arg_min_mul_semi(): | |||
| """ | |||
| Feature: test Argmin model parallel strategy | |||
| Description: partition the non-reduced axes | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = ((1, 4, 2), (1, 4, 2)) | |||
| strategy2 = ((4, 2, 1),) | |||
| net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2))) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| gen_inputs_and_compile_net_no_bias(net) | |||
| def test_arg_min_mul_auto(): | |||
| """ | |||
| Feature: test Argmin auto parallel strategy | |||
| Description: don't set the strategy | |||
| Expectation: compile success | |||
| """ | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = None | |||
| strategy2 = None | |||
| net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2))) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| gen_inputs_and_compile_net_no_bias(net) | |||
| class ArgMinWithValueNet2(nn.Cell): | |||
| def __init__(self, strategy1, strategy2, strategy3): | |||
| super(ArgMinWithValueNet2, self).__init__() | |||
| @@ -915,3 +1022,59 @@ def test_prod_mul_auto(): | |||
| net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2))) | |||
| context.set_auto_parallel_context(parallel_mode="auto_parallel") | |||
| gen_inputs_and_compile_net_no_bias(net) | |||
| def test_square_sum_all_mul(): | |||
| """ | |||
| Feature: test SquareSumAll model parallel strategy | |||
| Description: partition the reduced axes | |||
| Expectation: compile success | |||
| """ | |||
| class Net(nn.Cell): | |||
| def __init__(self, strategy1, strategy2): | |||
| super(Net, self).__init__() | |||
| self.mul1 = P.Mul().shard(strategy1) | |||
| self.square_sum_all = P.SquareSumAll().shard(strategy2) | |||
| def construct(self, x, y): | |||
| out = self.mul1(x, y) | |||
| out = self.square_sum_all(out, out) | |||
| return out | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = ((1, 1, 8), (1, 1, 8)) | |||
| strategy2 = ((2, 4, 1), (2, 4, 1)) | |||
| net = Net(strategy1, strategy2) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32) | |||
| y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32) | |||
| compile_net_no_bias(net, x, y) | |||
| def test_square_sum_all_mul2(): | |||
| """ | |||
| Feature: test SquareSumAll model parallel strategy | |||
| Description: partition the reduced axes | |||
| Expectation: compile success | |||
| """ | |||
| class Net(nn.Cell): | |||
| def __init__(self, stra_mul, stra_prod): | |||
| super(Net, self).__init__() | |||
| self.mul = P.Mul().shard(stra_mul) | |||
| self.square_sum_all = P.SquareSumAll().shard(stra_prod) | |||
| def construct(self, x, y): | |||
| out = self.mul(x, y) | |||
| out = self.square_sum_all(out, out) | |||
| return out | |||
| context.set_auto_parallel_context(device_num=8, global_rank=0) | |||
| strategy1 = ((1, 1, 8), (1, 1, 8)) | |||
| strategy2 = ((8, 1, 1), (8, 1, 1)) | |||
| net = Net(strategy1, strategy2) | |||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||
| x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32) | |||
| y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32) | |||
| compile_net_no_bias(net, x, y) | |||
| @@ -0,0 +1,195 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.common.api import _cell_graph_executor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from tests.ut.python.ops.test_math_ops import VirtualLoss | |||
# Every test in this module runs with the execution mode set to GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)
# Shared gradient operator built with get_all=True; GradWrap applies it to the wrapped network.
# NOTE(review): get_all=True presumably requests gradients w.r.t. all inputs - confirm against the GradOperation docs.
grad_all = C.GradOperation(get_all=True)
class Net(nn.Cell):
    """Cell wrapping a single UnsortedSegmentProd primitive sharded with (strategy1, strategy2)."""

    def __init__(self, strategy1, strategy2, num_segments):
        super().__init__()
        shard_strategy = (strategy1, strategy2)
        self.merge_op = P.UnsortedSegmentProd().shard(shard_strategy)
        self.num_segments = num_segments

    def construct(self, vectors, segment_ids):
        # num_segments is fixed at construction time and passed as the third operand.
        return self.merge_op(vectors, segment_ids, self.num_segments)
class GradWrap(nn.Cell):
    """Cell whose forward pass applies the module-level grad_all operator to the wrapped network."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x, y):
        grad_fn = grad_all(self.network)
        return grad_fn(x, y)
class NetWithLoss(nn.Cell):
    """Cell that feeds the wrapped network's output into a VirtualLoss."""

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y):
        return self.loss(self.network(x, y))
def compile_graph(x, y, segments, strategy1, strategy2, auto=False):
    """
    Build Net(strategy1, strategy2, segments) wrapped in NetWithLoss and GradWrap, then compile it.

    Args:
        x: first input passed to the compiled network (the vectors operand of Net).
        y: second input passed to the compiled network (the segment_ids operand of Net).
        segments: num_segments value handed to Net.
        strategy1: shard strategy for the first operand of UnsortedSegmentProd.
        strategy2: shard strategy for the second operand of UnsortedSegmentProd.
        auto: when True use "auto_parallel", otherwise "semi_auto_parallel".
    """
    parallel_mode = "auto_parallel" if auto else "semi_auto_parallel"
    context.set_auto_parallel_context(parallel_mode=parallel_mode)

    wrapped = GradWrap(NetWithLoss(Net(strategy1, strategy2, segments)))
    wrapped.set_auto_parallel()
    wrapped.set_train()
    _cell_graph_executor.compile(wrapped, x, y)
def test_unsortedsegmentprod_model_parallel_slice_1d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: 1-D vectors and segment_ids, both sharded 8 ways on 8 devices.
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    vectors = Tensor(np.ones(8), ms.float32)
    segment_ids = Tensor(np.ones(8), ms.int32)
    compile_graph(vectors, segment_ids, 16, (8,), (8,))
def test_unsortedsegmentprod_model_parallel_no_slice_1d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: 1-D vectors and segment_ids with the unsharded strategy (1,) on 8 devices.
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    vectors = Tensor(np.ones(8), ms.float32)
    segment_ids = Tensor(np.ones(8), ms.int32)
    compile_graph(vectors, segment_ids, 16, (1,), (1,))
def test_unsortedsegmentprod_model_parallel_index_slice_2d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: 2-D vectors sharded (4, 1) with 1-D segment_ids sharded (4,) on 4 devices.
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.arange(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (4, 1), (4,))
def test_unsortedsegmentprod_model_parallel_index_slice_3d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: unsorted_segment_prod net with model parallel strategy, slice 3d
        (3-D vectors sharded (2, 2, 1), 2-D segment_ids sharded (2, 2)).
    Expectation: compile raises ValueError.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    x = Tensor(np.ones((4, 4, 8)), ms.float32)
    y = Tensor(np.ones((4, 4)), ms.int32)
    num_segments = 16
    strategy1 = (2, 2, 1)
    strategy2 = (2, 2)
    # NOTE(review): this is one of the two cases in this file that pass 2-D segment_ids,
    # and both expect ValueError; presumably the operator rejects that input - confirm
    # which check actually raises.
    with pytest.raises(ValueError):
        compile_graph(x, y, num_segments, strategy1, strategy2)
def test_unsortedsegmentprod_model_parallel_vector_slice_2d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: 2-D vectors sharded (1, 4) with unsharded 1-D segment_ids on 4 devices.
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (1, 4), (1,))
def test_unsortedsegmentprod_model_parallel_vector_slice_3d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: 3-D vectors sharded (1, 2, 2) with unsharded 1-D segment_ids on 4 devices.
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (1, 2, 2), (1,))
def test_unsortedsegmentprod_model_parallel_index_vector_slice_2d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: 2-D vectors sharded (2, 2) with 1-D segment_ids sharded (2,) on 4 devices.
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    vectors = Tensor(np.ones((4, 8)), ms.float32)
    segment_ids = Tensor(np.ones(4), ms.int32)
    compile_graph(vectors, segment_ids, 4, (2, 2), (2,))
def test_unsortedsegmentprod_model_parallel_index_vector_slice_3d():
    """
    Feature: distribute operator unsorted_segment_prod in auto parallel.
    Description: unsorted_segment_prod net with strategy in semi auto parallel, slice 3d
        (3-D vectors sharded (2, 1, 2), 2-D segment_ids sharded (2, 1)).
    Expectation: compile raises ValueError.
    """
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    x = Tensor(np.ones((4, 4, 8)), ms.float32)
    y = Tensor(np.ones((4, 4)), ms.int32)
    num_segments = 16
    strategy1 = (2, 1, 2)
    strategy2 = (2, 1)
    # NOTE(review): this is one of the two cases in this file that pass 2-D segment_ids,
    # and both expect ValueError; presumably the operator rejects that input - confirm
    # which check actually raises.
    with pytest.raises(ValueError):
        compile_graph(x, y, num_segments, strategy1, strategy2)