@@ -54,20 +54,39 @@ Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
 Status BatchParallelInfo::InferDevMatrixShape() {
   dev_matrix_shape_.push_back(stage_device_size_);
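+  // When the operator's input is a shape value rather than a tensor, record the sliced shape here so
+  // that ReplaceNodeInputOrAttrs can later swap it into the node.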
+  if (need_replace_input_ && !inputs_shape_.empty()) {
+    replace_shape_ = inputs_shape_[0];
+    if (!replace_shape_.empty()) {
+      replace_shape_[0] /= stage_device_size_;
+    }
+  }
   return SUCCESS;
 }
 
 Status BatchParallelInfo::InferForwardCommunication() { return SUCCESS; }
 
 Status BatchParallelInfo::InferTensorMap() {
-  if (strategy_->GetInputDim()[0][0] != stage_device_size_) {
-    MS_LOG(ERROR) << name_ << " : It is not a valid data parallel strategy.";
+  auto strategy = strategy_->GetInputDim();
+  if (strategy.empty()) {
+    MS_LOG(INFO) << name_ << ": the strategy is empty";
+    return SUCCESS;
+  }
+  if (strategy[0].empty()) {
+    MS_LOG(INFO) << name_ << ": the first element of strategy is empty";
+    return FAILED;
+  }
+  if (strategy[0][0] != stage_device_size_) {
+    MS_LOG(ERROR) << name_ << ": It is not a valid data parallel strategy.";
     return FAILED;
   }
   for (size_t i = 0; i < inputs_shape_.size(); i++) {
     Shape tensor_map_index;
     for (size_t j = 0; j < inputs_shape_[i].size(); ++j) {
-      if (strategy_->GetInputDim()[i][j] == stage_device_size_ && j == 0) {
+      if (strategy[i][j] == stage_device_size_ && j == 0) {
         tensor_map_index.push_back(0);
       } else {
         tensor_map_index.push_back(MAP_NONE);
@@ -89,25 +108,23 @@ Status BatchParallelInfo::InferTensorMap() {
   return SUCCESS;
 }
 
-Strategys BatchParallelInfo::GetOutputsStrategy() {
-  Strategys outputs_strategy;
+Status BatchParallelInfo::GetAttrs() {
+  // If the operator's input is a shape (not a tensor), assign the shape value to inputs_shape_.
+  if (!inputs_shape_.empty()) {
+    return SUCCESS;
+  }
-  for (size_t i = 0; i < outputs_shape_.size(); ++i) {
-    Dimensions strategy;
-    for (size_t j = 0; j < outputs_shape_[i].size(); ++j) {
-      if (i == 0 && j == 0) {
-        strategy.push_back(stage_device_size_);
-      } else {
-        strategy.push_back(1);
-      }
-    }
-    outputs_strategy.push_back(strategy);
+  if (input_value_.empty()) {
+    return SUCCESS;
   }
-  return outputs_strategy;
-}
+  auto shape_ptr = input_value_[0]->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(shape_ptr);
-Status BatchParallelInfo::GetAttrs() { return SUCCESS; }
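+  // Record the shape value as this operator's input shape; ReplaceNodeInputOrAttrs will swap it into the node.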
+  inputs_shape_.push_back(GetValue<Shape>(shape_ptr));
+  need_replace_input_ = true;
+  return SUCCESS;
+}
 
 Status BatchParallelInfo::Init(const StrategyPtr &strategy) {
   if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
@@ -158,5 +175,30 @@ Status BatchParallelInfo::InferAsLossDivisor() {
   as_loss_divisor_ = 1;
   return SUCCESS;
 }
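+
+// Swap the shape-value input (input[1]) of the cnode for the sliced shape computed in InferDevMatrixShape.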
+void BatchParallelInfo::ReplaceNodeInputOrAttrs() {
+  if (!need_replace_input_) {
+    return;
+  }
+  auto cnode = cnode_;
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (cnode->size() != 2) {
+    MS_LOG(EXCEPTION) << name_ << ": The size of the cnode's inputs must be 2";
+  }
+  if (!IsValueNode<ValueTuple>(cnode->input(1))) {
+    MS_LOG(EXCEPTION) << name_ << ": The input[1] of the cnode is not a ValueTuple.";
+  }
+  auto func_graph = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  ValuePtr replace_shape = MakeValue(replace_shape_);
+  AnfNodePtr val = NewValueNode(replace_shape);
+  (void)manager->Replace(cnode->input(1), val);
+}
 }  // namespace parallel
 }  // namespace mindspore
@@ -41,6 +41,7 @@ class BatchParallelInfo : public OperatorInfo {
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -48,11 +49,12 @@ class BatchParallelInfo : public OperatorInfo {
   Status InferDevMatrixShape() override;
   Status InferTensorMap() override;
   Status GetAttrs() override;
-  Strategys GetOutputsStrategy();
   Status InferAsLossDivisor() override;
 
  private:
-  int64_t dev_num_;
+  int64_t dev_num_ = 1;
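+  // need_replace_input_ is set when the operator's input is a shape value instead of a tensor;
+  // replace_shape_ holds the sliced shape that replaces the node input.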
+  bool need_replace_input_ = false;
+  Shape replace_shape_;
 };
 
 class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo {
@@ -17,6 +17,7 @@
 #include "frontend/parallel/ops_info/conv2d_info.h"
 
 #include <algorithm>
 #include <functional>
+#include <cmath>
 #include <memory>
 #include <utility>
@@ -1041,7 +1042,8 @@ Status Conv2DBackpropInputInfo::InferMirrorOps() {
   return SUCCESS;
 }
 
-void Conv2DBackpropInputInfo::UpdateOutShape(const CNodePtr &cnode) {
+void Conv2DBackpropInputInfo::UpdateOutShape() {
+  auto cnode = cnode_;
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 4) {
     MS_LOG(EXCEPTION) << name_ << ": The size of cnode's inputs must be 4, but got " << cnode->size();
@@ -1191,5 +1193,7 @@ void Conv2DBackpropInputInfo::InferNewPadList() {
   MS_LOG(INFO) << name_ << ": the new pad list is " << new_pad_list_ << ", the required size of current rank is "
                << current_rank_required_size << ", new pad all is " << pad_all;
 }
+
+void Conv2DBackpropInputInfo::ReplaceNodeInputOrAttrs() { UpdateOutShape(); }
 }  // namespace parallel
 }  // namespace mindspore
@@ -127,7 +127,8 @@ class Conv2DBackpropInputInfo : public Conv2DInfo {
                           const PrimitiveAttrs &attrs)
       : Conv2DInfo(name, inputs_shape, outputs_shape, attrs) {}
   ~Conv2DBackpropInputInfo() override = default;
-  void UpdateOutShape(const CNodePtr &cnode);
+  void UpdateOutShape();
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status GetAttrs() override;
@@ -24,6 +24,8 @@
 #include "ir/value.h"
 #include "pipeline/jit/resource.h"
 #include "frontend/parallel/auto_parallel/costmodel.h"
+#include "frontend/parallel/graph_util/node_info.h"
+#include "frontend/parallel/step_parallel_utils.h"
 #include "frontend/parallel/device_matrix.h"
 #include "frontend/parallel/strategy.h"
@@ -213,7 +215,8 @@ void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
 // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
 // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
 // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
-std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
+std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp() {
+  auto cnode = cnode_;
   std::vector<Operator> replace_ops;
   MS_EXCEPTION_IF_NULL(cnode);
   PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
@@ -262,5 +265,46 @@ std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodeP
   replace_ops.push_back(replace_op);
   return replace_ops;
 }
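+
+// File-local copy of the helper formerly in step_parallel.cc: build a new cnode from replace_op and
+// swap it in for the given node.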
+static void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
+  FuncGraphPtr func_graph = node->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  FuncGraphManagerPtr manager = func_graph->manager();
+  if (manager == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
+  }
+  std::string instance_name = CreateInstanceName(node, 0);
+  std::vector<AnfNodePtr> replace_input;
+  replace_input = ReplaceOpInput(replace_op, instance_name, node);
+  if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    replace_input.push_back(node->input(3));
+  }
+  CNodePtr replace_node = func_graph->NewCNode(replace_input);
+  MS_EXCEPTION_IF_NULL(replace_node);
+  ScopePtr scope = node->scope();
+  MS_EXCEPTION_IF_NULL(scope);
+  replace_node->set_scope(scope);
+  replace_node->set_in_forward_flag(true);
+  replace_input[0]->set_scope(scope);
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
+  PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
+  SetUserAttrs(origin_prim->attrs(), prim);
+  (void)manager->Replace(node, replace_node);
+}
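+
+// Replace the DropoutGenMask input of this DropoutDoMask cnode with the newly generated mask operator.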
+void DropoutDoMaskInfo::ReplaceNodeInputOrAttrs() {
+  auto cnode = cnode_;
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<Operator> replace_op = GetDropoutGenMaskReplaceOp();
+  if (replace_op.empty()) {
+    MS_LOG(DEBUG) << name_ << ": No need to replace dropout_gen_mask";
+    return;
+  }
+  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
+    MS_LOG(EXCEPTION) << name_ << ": The size of the DropoutDoMask cnode's inputs is not "
+                      << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
+  }
+  ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
+}
 }  // namespace parallel
 }  // namespace mindspore
@@ -41,7 +41,8 @@ class DropoutDoMaskInfo : public OperatorInfo {
   Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::shared_ptr<Strategys> GenerateBatchStrategies() override;
-  std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
+  std::vector<Operator> GetDropoutGenMaskReplaceOp();
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status CheckStrategy(const StrategyPtr &strategy) override;
@@ -1022,6 +1022,9 @@ Status PrepareStrategyBase(int64_t stage_id, size_t dev_num, const Shapes &input
 }
 
 std::shared_ptr<Strategys> OperatorInfo::GenerateBatchStrategies() {
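+  // Operators whose input is a shape value have an empty inputs_shape_ until GetAttrs has run, so infer attrs first.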
+  if (inputs_shape_.empty() && InferAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
+  }
   ComputeBatchSplitFlagList();
   return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_);
 }
@@ -182,6 +182,7 @@ class OperatorInfo {
   int32_t stage_id() const { return stage_id_; }
   Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);
   Status CreateGroupForOptShard(TensorLayout *const tensor_layout, std::vector<Group> *group);
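+  // Hook for operators that must rewrite a node input or attribute after parallel processing; default is a no-op.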
+  virtual void ReplaceNodeInputOrAttrs() {}
 
   // Key for user data.
   constexpr static char key[] = "OpInfo";
@@ -323,6 +323,7 @@ constexpr char GATHERV2[] = "Gather";
 constexpr char SPARSE_GATHERV2[] = "SparseGatherV2";
 constexpr char STRIDEDSLICE[] = "StridedSlice";
 constexpr char SLICE[] = "Slice";
+constexpr char UNIFORM_REAL[] = "UniformReal";
 constexpr char BROADCAST[] = "Broadcast";
 constexpr char BROADCAST_TO[] = "BroadcastTo";
 constexpr char SQRT[] = "Sqrt";
@@ -155,7 +155,8 @@ Status TileInfo::InferMirrorOps() {
   return SUCCESS;
 }
 
-void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
+void TileInfo::UpdateMultiples() {
+  auto cnode = cnode_;
   MS_EXCEPTION_IF_NULL(cnode);
   if (cnode->size() != 3) {
     MS_LOG(EXCEPTION) << name_ << ": The size of tile cnode's inputs must be 3";
@@ -175,6 +176,8 @@ void TileInfo::UpdateMultiples(const CNodePtr &cnode) {
   (void)manager->Replace(cnode->input(2), val);
 }
 
+void TileInfo::ReplaceNodeInputOrAttrs() { UpdateMultiples(); }
+
 std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
   if (InferAttrs() != SUCCESS) {
     MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
@@ -42,7 +42,8 @@ class TileInfo : public OperatorInfo {
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
   Status SetCostUnderStrategy(const StrategyPtr &) override;
   std::shared_ptr<Strategys> GenerateBatchStrategies() override;
-  void UpdateMultiples(const CNodePtr &cnode);
+  void UpdateMultiples();
+  void ReplaceNodeInputOrAttrs() override;
 
  protected:
   Status GetAttrs() override;
@@ -59,33 +59,11 @@ namespace mindspore {
 namespace parallel {
 static const std::set<std::string> COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER};
 static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE};
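+// Ops that consume no tensor input (e.g. UniformReal takes only a shape value); tensor splitting must skip their users.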
+static const std::set<std::string> NO_INPUT_TENSOR_OPS = {UNIFORM_REAL};
 
 // g_RefMap, for CNode B input i is a RefKey[Parameter C],
 // it will be one item in map with key: C, and value: (B, i)
 static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap;
 
-void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
-  if (new_node_input.empty()) {
-    return;
-  }
-  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
-  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  auto attrs = prim->attrs();
-  auto iter = attrs.find(GROUP);
-  if (iter != attrs.end()) {
-    auto value = iter->second;
-    MS_EXCEPTION_IF_NULL(value);
-    if (value->isa<StringImm>()) {
-      std::string hash_name = value->cast<StringImmPtr>()->value();
-      MS_EXCEPTION_IF_NULL(g_device_manager);
-      std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
-      (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
-    }
-  }
-}
 
 void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool accu_flag) {
   if (new_node_input.empty()) {
     return;
@@ -282,17 +260,6 @@ static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, cons
   return new_node;
 }
 
-std::string CreateInstanceName(const CNodePtr &node, size_t index) {
-  MS_EXCEPTION_IF_NULL(node);
-  if (!IsValueNode<Primitive>(node->input(0))) {
-    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
-  }
-  std::string name_base = node->fullname_with_scope();
-  std::string name = name_base + "_" + std::to_string(index);
-  std::string instance_name = HashInstanceName(name);
-  return instance_name;
-}
 
 void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   // step1:get graph manager distribute_operator
@@ -729,7 +696,8 @@ void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager)
       MS_EXCEPTION_IF_NULL(prim_anf_node);
       PrimitivePtr use_cnode_prim = prim_anf_node->value()->cast<PrimitivePtr>();
       MS_EXCEPTION_IF_NULL(use_cnode_prim);
-      if (use_cnode_prim->name() == DEPEND && node_pair.second != 1) {
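+      // Also skip users that take no tensor input, such as UniformReal.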
+      if ((use_cnode_prim->name() == DEPEND && node_pair.second != 1) ||
+          NO_INPUT_TENSOR_OPS.find(use_cnode_prim->name()) != NO_INPUT_TENSOR_OPS.end()) {
        continue;
      }
      if (IsParallelCareNode(use_cnode)) {
@@ -742,76 +710,6 @@ void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager)
   }
 }
 
-std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
-                                       const CNodePtr &node) {
-  OperatorArgs arg_replace_op = replace_op.second;
-  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
-  if (pyop_instance == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
-  }
-  OperatorParams params = arg_replace_op.second;
-  if (node->inputs().size() < 2) {
-    // GetNext operator dose not has input
-    if (node->inputs().size() == 1) {
-      return {NewValueNode(pyop_instance)};
-    }
-    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
-  }
-  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
-  if (replace_op.first == EMBEDDING_LOOKUP) {
-    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
-  }
-  if (!params.empty()) {
-    Param param_first = *(params.begin());
-    int64_t first_position = param_first.second;
-    if (first_position == 1) {
-      replace_input.pop_back();
-    }
-    for (auto &param : params) {
-      AnfNodePtr val = NewValueNode(param.first.second);
-      if (val == nullptr) {
-        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
-      }
-      int64_t position = param.second;
-      (void)replace_input.insert(replace_input.begin() + position, val);
-    }
-  } else if (replace_op.first == SYNC_BATCH_NORM) {
-    for (size_t i = 2; i < node->inputs().size(); ++i) {
-      replace_input.push_back(node->input(i));
-    }
-  }
-  SetCommunicationOpGroupLabel(replace_input);
-  return replace_input;
-}
-
-void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
-  FuncGraphPtr func_graph = node->func_graph();
-  MS_EXCEPTION_IF_NULL(func_graph);
-  FuncGraphManagerPtr manager = func_graph->manager();
-  if (manager == nullptr) {
-    MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
-  }
-  std::string instance_name = CreateInstanceName(node, 0);
-  std::vector<AnfNodePtr> replace_input;
-  replace_input = ReplaceOpInput(replace_op, instance_name, node);
-  if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
-    replace_input.push_back(node->input(3));
-  }
-  CNodePtr replace_node = func_graph->NewCNode(replace_input);
-  MS_EXCEPTION_IF_NULL(replace_node);
-  ScopePtr scope = node->scope();
-  MS_EXCEPTION_IF_NULL(scope);
-  replace_node->set_scope(scope);
-  replace_node->set_in_forward_flag(true);
-  replace_input[0]->set_scope(scope);
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(replace_node->input(0));
-  PrimitivePtr origin_prim = GetValueNode<PrimitivePtr>(node->input(0));
-  SetUserAttrs(origin_prim->attrs(), prim);
-  (void)manager->Replace(node, replace_node);
-}
 
 void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
   // step1:get graph manager distribute_operator
   OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
@@ -2522,64 +2420,6 @@ void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode
   }
 }
 
-void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(distribute_operator);
-  MS_EXCEPTION_IF_NULL(cnode);
-  std::string op_name = distribute_operator->name();
-  if (op_name.find(DROPOUT_DO_MASK) == std::string::npos) {
-    return;
-  }
-  DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
-  MS_EXCEPTION_IF_NULL(dropout_do_mask);
-  std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
-  if (replace_op.empty()) {
-    MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
-    return;
-  }
-  if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
-    MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
-  }
-  ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
-}
-
-void HandleTileNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->size() < 3 || !IsValueNode<Primitive>(cnode->input(0))) {
-    return;
-  }
-  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  if (prim->name() != TILE) {
-    return;
-  }
-  TileInfoPtr tile = std::dynamic_pointer_cast<TileInfo>(distribute_operator);
-  MS_EXCEPTION_IF_NULL(tile);
-  tile->UpdateMultiples(cnode);
-}
-
-void HandleConv2dTransposeNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  if (cnode->size() != 4 || !IsValueNode<Primitive>(cnode->input(0))) {
-    return;
-  }
-  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-  if (prim->name() != CONV2D_BACK_PROP_INPUT && prim->name() != CONV2D_TRANSPOSE) {
-    return;
-  }
-  Conv2DBackpropInputInfoPtr op_ptr = std::dynamic_pointer_cast<Conv2DBackpropInputInfo>(distribute_operator);
-  MS_EXCEPTION_IF_NULL(op_ptr);
-  op_ptr->UpdateOutShape(cnode);
-}
-
-void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
-  HandleDropoutNode(distribute_operator, cnode);
-  HandleTileNode(distribute_operator, cnode);
-  HandleConv2dTransposeNode(distribute_operator, cnode);
-}
-
 std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
   // J->CNode->Graph
   std::set<FuncGraphPtr> graph_set;
@@ -2719,7 +2559,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
       BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
     }
 
-    HandleSpecialNode(distribute_operator, cnode);
+    distribute_operator->ReplaceNodeInputOrAttrs();
   } else if (IsValueNode<Tensor>(node) || IsValueNode<ValueList>(node) || IsValueNode<ValueTuple>(node)) {
     StepSplitTensor(node, manager);
   }
@@ -55,7 +55,6 @@ struct CommInfo {
 };
 
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
-std::string CreateInstanceName(const CNodePtr &node, size_t index);
 void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node);
 void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node,
@@ -79,9 +78,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
 void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node,
                         const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node);
-std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
-                                       const CNodePtr &node);
 void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);
 void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node);
@@ -92,7 +92,10 @@ Shapes GetValueListShape(const AnfNodePtr &node) {
   }
   for (auto &ele : inputs_seq) {
     auto tensor = ele->cast<tensor::TensorPtr>();
-    MS_EXCEPTION_IF_NULL(tensor);
+    if (tensor == nullptr) {
+      MS_LOG(WARNING) << "The value node is not a tensor";
+      break;
+    }
     auto one_shape = tensor->shape();
     shapes.push_back(one_shape);
   }
@@ -145,5 +148,83 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   }
   return shapes;
 }
+
+std::string CreateInstanceName(const CNodePtr &node, size_t index) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!IsValueNode<Primitive>(node->input(0))) {
+    MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive";
+  }
+  std::string name_base = node->fullname_with_scope();
+  std::string name = name_base + "_" + std::to_string(index);
+  std::string instance_name = HashInstanceName(name);
+  return instance_name;
+}
+
+void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) {
+  if (new_node_input.empty()) {
+    return;
+  }
+  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto attrs = prim->attrs();
+  auto iter = attrs.find(GROUP);
+  if (iter != attrs.end()) {
+    auto value = iter->second;
+    MS_EXCEPTION_IF_NULL(value);
+    if (value->isa<StringImm>()) {
+      std::string hash_name = value->cast<StringImmPtr>()->value();
+      MS_EXCEPTION_IF_NULL(g_device_manager);
+      std::string rank_list_name = g_device_manager->FindRankListNameByHashName(hash_name);
+      (void)prim->AddAttr(GROUP_RANKS, MakeValue(rank_list_name));
+    }
+  }
+}
+
+std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
+                                       const CNodePtr &node) {
+  OperatorArgs arg_replace_op = replace_op.second;
+  ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name);
+  if (pyop_instance == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: " << replace_op.first << " CreatOpInstance failed";
+  }
+  OperatorParams params = arg_replace_op.second;
+  if (node->inputs().size() < 2) {
+    // the GetNext operator does not have inputs
+    if (node->inputs().size() == 1) {
+      return {NewValueNode(pyop_instance)};
+    }
+    MS_LOG(EXCEPTION) << "Failure: " << node->ToString() << " size is smaller than 2";
+  }
+  std::vector<AnfNodePtr> replace_input = {NewValueNode(pyop_instance), node->input(1)};
+  if (replace_op.first == EMBEDDING_LOOKUP) {
+    replace_input = {NewValueNode(pyop_instance), node->input(1), node->input(2)};
+  }
+  if (!params.empty()) {
+    Param param_first = *(params.begin());
+    int64_t first_position = param_first.second;
+    if (first_position == 1) {
+      replace_input.pop_back();
+    }
+    for (auto &param : params) {
+      AnfNodePtr val = NewValueNode(param.first.second);
+      if (val == nullptr) {
+        MS_LOG(EXCEPTION) << "Failure:val is nullptr";
+      }
+      int64_t position = param.second;
+      (void)replace_input.insert(replace_input.begin() + position, val);
+    }
+  } else if (replace_op.first == SYNC_BATCH_NORM) {
+    for (size_t i = 2; i < node->inputs().size(); ++i) {
+      replace_input.push_back(node->input(i));
+    }
+  }
+  SetCommunicationOpGroupLabel(replace_input);
+  return replace_input;
+}
 }  // namespace parallel
 }  // namespace mindspore
@@ -28,6 +28,10 @@ namespace parallel {
 bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name);
 bool IsParallelCareNode(const CNodePtr &cnode);
 Shapes GetNodeShape(const AnfNodePtr &node);
+std::string CreateInstanceName(const CNodePtr &node, size_t index);
+void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input);
+std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::string &instance_name,
+                                       const CNodePtr &node);
 }  // namespace parallel
 }  // namespace mindspore
@@ -0,0 +1,59 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _cell_graph_executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().shard(strategy1)
+        self.neg = P.Neg().shard(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+        self.uniform_real = P.UniformReal()
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.neg(out)
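+        # UniformReal takes only a shape value, so it has no tensor input for the parallel pass to split.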
+        z = self.uniform_real((128, 64, 32))
+        out = out + z
+        return out
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    train_net.set_train()
+    _cell_graph_executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_batch_parallel_replace_shape():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1),)
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)