| @@ -913,14 +913,13 @@ double GatherV2PCost::GetBackwardCommCost(const std::vector<TensorInfo> &inputs, | |||||
| double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, | double UniformCandidateSamplerCost::GetForwardComputationCost(const std::vector<TensorInfo> &inputs, | ||||
| const std::vector<TensorInfo> &outputs, | const std::vector<TensorInfo> &outputs, | ||||
| int64_t stage_id) const { | int64_t stage_id) const { | ||||
| double result = 0.0; | |||||
| Shape input0_slice_shape = inputs[0].slice_shape(); | Shape input0_slice_shape = inputs[0].slice_shape(); | ||||
| if (inputs_type_lengths_.size() != inputs.size()) { | if (inputs_type_lengths_.size() != inputs.size()) { | ||||
| MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() | MS_LOG(EXCEPTION) << "Invalid inputs type size " << inputs_type_lengths_.size() | ||||
| << " for UniformCandidateSampler cost"; | << " for UniformCandidateSampler cost"; | ||||
| } | } | ||||
| result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]); | |||||
| double result = ListProduct(input0_slice_shape) * static_cast<double>(inputs_type_lengths_[0]); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -279,6 +279,5 @@ Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,7 +30,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) { | Status UniformCandidateSamplerInfo::GetUniformSamplerAttrInt64(const std::string &args, int64_t *value) { | ||||
| auto iter = attrs_.find(args); | auto iter = attrs_.find(args); | ||||
| if (iter == attrs_.end()) { | if (iter == attrs_.end()) { | ||||
| @@ -276,7 +275,6 @@ Status UniformCandidateSamplerInfo::InitForCostModel(const StrategyPtr &strategy | |||||
| ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) { | ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) { | ||||
| auto input_strategy = strategy_->GetInputDim().at(0); | auto input_strategy = strategy_->GetInputDim().at(0); | ||||
| // Only when the axis-1 is sharded, we need to modify the attribute | // Only when the axis-1 is sharded, we need to modify the attribute | ||||
| if (input_strategy.size() == 2 && input_strategy[1] > 1) { | if (input_strategy.size() == 2 && input_strategy[1] > 1) { | ||||
| if (ComputeReplaceGraph(cnode) != SUCCESS) { | if (ComputeReplaceGraph(cnode) != SUCCESS) { | ||||
| @@ -311,6 +309,5 @@ Status UniformCandidateSamplerInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
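For reference, the guard in replace_graph above only rewrites the graph when the input strategy is two-dimensional and its second axis is split across more than one device. A minimal sketch of that predicate, with Dimensions as an assumed stand-in for the strategy row type:

    #include <cstdint>
    #include <vector>

    using Dimensions = std::vector<int64_t>;

    // True only when axis 1 of the first input is sharded, which is the one
    // case where the sampler's attributes must be patched.
    bool NeedsReplaceGraph(const Dimensions &input_strategy) {
      return input_strategy.size() == 2 && input_strategy[1] > 1;
    }

    // NeedsReplaceGraph({8, 1}) == false; NeedsReplaceGraph({4, 2}) == true.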
| @@ -331,7 +331,6 @@ Status UnsortedSegmentMinInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo | // The UnsortedSegmentMaxInfo is almost same with UnsortedSegmentMinInfo | ||||
| // Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op | // Except the reduceMin op in the ComputeReplaceGraph is replaced with reduceMax op | ||||
| ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) { | ReplaceGraphPtr UnsortedSegmentMaxInfo::replace_graph(const CNodePtr &cnode) { | ||||
| @@ -351,9 +350,8 @@ Status UnsortedSegmentMaxInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| MS_LOG(ERROR) << "GenerateGraph Init failed"; | MS_LOG(ERROR) << "GenerateGraph Init failed"; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // Get the attributes of the UnsortedSegmentMin | |||||
| // Get the attributes of the UnsortedSegmentMax | |||||
| auto num_segments = GetValue<int64_t>(input_value_.at(2)); | auto num_segments = GetValue<int64_t>(input_value_.at(2)); | ||||
| // Step1: Output branch | |||||
| auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(), | auto segment_max = gen_g.PushBack({gen_g.NewOpInst(UNSORTED_SEGMENT_MAX), gen_g.virtual_input_node(), | ||||
| gen_g.virtual_input_node(), CreatInt64Imm(num_segments)}); | gen_g.virtual_input_node(), CreatInt64Imm(num_segments)}); | ||||
| auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)}); | auto expandim_output = gen_g.PushBack({gen_g.NewOpInst(EXPAND_DIMS), segment_max, CreatInt64Imm(0)}); | ||||
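The replacement graph above is sound because an unsorted segment max can be computed per shard and then merged with an element-wise max across shards; the ExPandDims node stacks the per-shard partials for exactly that reduction. A self-contained sketch of the property (plain C++, not MindSpore code):

    #include <algorithm>
    #include <cstdint>
    #include <limits>
    #include <vector>

    // Per-shard UnsortedSegmentMax over a flat input.
    std::vector<float> SegmentMax(const std::vector<float> &data,
                                  const std::vector<int64_t> &segment_ids,
                                  int64_t num_segments) {
      std::vector<float> out(num_segments, std::numeric_limits<float>::lowest());
      for (size_t i = 0; i < data.size(); ++i) {
        out[segment_ids[i]] = std::max(out[segment_ids[i]], data[i]);
      }
      return out;
    }

    // Merging two shards' partial results element-wise gives the same answer
    // as running SegmentMax over the concatenated inputs.
    std::vector<float> CombineShards(const std::vector<float> &a,
                                     const std::vector<float> &b) {
      std::vector<float> out(a.size());
      for (size_t i = 0; i < a.size(); ++i) out[i] = std::max(a[i], b[i]);
      return out;
    }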
| @@ -78,7 +78,6 @@ class UnsortedSegmentMinInfo : public UnsortedSegmentOpInfo { | |||||
| protected: | protected: | ||||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | Status ComputeReplaceGraph(const CNodePtr &cnode); | ||||
| }; | }; | ||||
| class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { | class UnsortedSegmentMaxInfo : public UnsortedSegmentOpInfo { | ||||
| public: | public: | ||||
| UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | UnsortedSegmentMaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | #include <memory> | ||||
| #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" | #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" | ||||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||||
| #include "frontend/parallel/auto_parallel/graph_costmodel.h" | #include "frontend/parallel/auto_parallel/graph_costmodel.h" | ||||
| #include "frontend/parallel/ops_info/ops_utils.h" | #include "frontend/parallel/ops_info/ops_utils.h" | ||||
| #include "frontend/parallel/group_manager.h" | #include "frontend/parallel/group_manager.h" | ||||
| @@ -33,8 +32,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map; | static std::unordered_map<AnfNodePtr, std::set<int>> parameter_color_map; | ||||
| static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); | |||||
| static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); | |||||
| static int send_tag = 0; | static int send_tag = 0; | ||||
| static int recv_tag = 0; | static int recv_tag = 0; | ||||
| @@ -236,7 +233,7 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode | |||||
| manager_->SetEdge(use_node, index, recv); | manager_->SetEdge(use_node, index, recv); | ||||
| } | } | ||||
| static std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) { | |||||
| std::pair<bool, int> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) { | |||||
| std::set<int> tag_set; | std::set<int> tag_set; | ||||
| auto node_stage = node->stage(); | auto node_stage = node->stage(); | ||||
| int min_tag = node_stage; | int min_tag = node_stage; | ||||
| @@ -368,7 +365,7 @@ void PipelineTransformer::ElimGraphStage() { | |||||
| } | } | ||||
| } | } | ||||
| static bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { | |||||
| bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { | |||||
| ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>(); | PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>(); | ||||
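Both helpers above move from translation-unit-local statics to private members of PipelineTransformer, so each definition gains the class qualifier and drops `static`. The mechanical shape of that refactor on a toy class (names illustrative, not the MindSpore ones):

    #include <string>

    // Before: a file-local helper in the .cc file.
    //   static bool IsFoo(const std::string &name) { return name == "Foo"; }

    // After: the same logic declared private in the header and defined with
    // the class qualifier, scoping it to the class instead of the file.
    class Transformer {
     public:
      bool Run(const std::string &name) { return IsFoo(name); }

     private:
      bool IsFoo(const std::string &name) { return name == "Foo"; }
    };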
| @@ -18,9 +18,11 @@ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_ | #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_ | ||||
| #include <utility> | #include <utility> | ||||
| #include <string> | |||||
| #include "ir/value.h" | #include "ir/value.h" | ||||
| #include "ir/graph_utils.h" | #include "ir/graph_utils.h" | ||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -49,6 +51,8 @@ class PipelineTransformer { | |||||
| void ElimParameter(); | void ElimParameter(); | ||||
| private: | private: | ||||
| std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); | |||||
| bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); | |||||
| void DoBroadCast(const FuncGraphPtr &func); | void DoBroadCast(const FuncGraphPtr &func); | ||||
| SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, const int &user_node_stage, | SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, const int &user_node_stage, | ||||
| const int &node_stage); | const int &node_stage); | ||||
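Adding `<string>` alongside the new declarations follows include-what-you-use: once the header itself names `std::string` (in the IsSomePrimitive signature), it must not lean on transitive includes. A minimal sketch of a self-sufficient header under that rule (file and class names hypothetical):

    // transformer.h (hypothetical): includes everything its own
    // declarations mention, so it compiles in isolation.
    #ifndef TRANSFORMER_H_
    #define TRANSFORMER_H_

    #include <string>
    #include <utility>

    class Transformer {
     private:
      std::pair<bool, int> Classify(const std::string &name);
    };

    #endif  // TRANSFORMER_H_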
| @@ -54,11 +54,6 @@ static const std::set<std::string> INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; | |||||
| // g_RefMap, for CNode B input i is a RefKey[Parameter C], | // g_RefMap, for CNode B input i is a RefKey[Parameter C], | ||||
| // it will be one item in map with key: C, and value: (B, i) | // it will be one item in map with key: C, and value: (B, i) | ||||
| static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap; | static std::map<AnfNodePtr, std::pair<AnfNodePtr, int64_t>> g_RefMap; | ||||
| static void HandleNoUsedParameter(const FuncGraphPtr &root); | |||||
| static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, | |||||
| const std::string &instance_name); | |||||
| static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr ¶meter, | |||||
| const std::string &opt_shard_group); | |||||
| void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { | void SetCommunicationOpGroupLabel(std::vector<AnfNodePtr> new_node_input) { | ||||
| if (new_node_input.empty()) { | if (new_node_input.empty()) { | ||||