@@ -44,6 +44,7 @@ namespace parallel {
 #define DEFAULT_TENSOR_SLICE_ALIGNMENT_SIZE 16
 #define DEFAULT_FULLY_USE_DEVICES true
 #define DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW false
+#define DEFAULT_IS_MULTI_SUBGRAPHS false
 class CostGraph;
 using CostGraphPtr = std::shared_ptr<CostGraph>;
@@ -46,6 +46,7 @@ void CostModelContext::ResetCostModel() {
   costmodel_communi_threshold_ = DEFAULT_COST_MODEL_COMMUNI_THRESHOLD;
   costmodel_communi_const_ = DEFAULT_COST_MODEL_COMMUNI_CONST;
   costmodel_communi_bias_ = DEFAULT_COST_MODEL_COMMUNI_BIAS;
+  is_multi_subgraphs_ = DEFAULT_IS_MULTI_SUBGRAPHS;
   costmodel_allreduce_fusion_algorithm_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALGORITHM;
   costmodel_allreduce_fusion_times_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TIMES;
   costmodel_allreduce_fusion_tail_percent_ = DEFAULT_COST_MODEL_ALLREDUCE_FUSION_TAIL_PERCENT;
@@ -84,6 +85,7 @@ void CostModelContext::set_costmodel_communi_const(double cm_communi_const) {
 void CostModelContext::set_costmodel_communi_bias(double cm_communi_bias) { costmodel_communi_bias_ = cm_communi_bias; }
+void CostModelContext::set_multi_subgraphs(bool multi_graphs) { is_multi_subgraphs_ = multi_graphs; }
 void CostModelContext::set_costmodel_allreduce_fusion_algorithm(int32_t algorithm) {
   costmodel_allreduce_fusion_algorithm_ = algorithm;
 }
@@ -67,6 +67,9 @@ class CostModelContext {
   void set_costmodel_communi_bias(double);
   double costmodel_communi_bias() const { return costmodel_communi_bias_; }
+  void set_multi_subgraphs(bool);
+  bool is_multi_subgraphs() const { return is_multi_subgraphs_; }
   void set_costmodel_allreduce_fusion_algorithm(int32_t);
   int32_t costmodel_allreduce_fusion_algorithm() const { return costmodel_allreduce_fusion_algorithm_; }
@@ -138,6 +141,8 @@ class CostModelContext {
   // COST_MODEL_COMMUNI_BIAS
   double costmodel_communi_bias_;
+  bool is_multi_subgraphs_;
   int32_t costmodel_allreduce_fusion_algorithm_;
   int32_t costmodel_allreduce_fusion_times_;
@@ -426,13 +426,13 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
   return operator_info;
 }
-Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
+// Using CNode's UniqueIds to construct nodes
+Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
   entire_costgraph = std::make_shared<CostGraph>();
   entire_costgraph->SetDeviceMemoryAndCostParameter();
-  bool new_operator = true, first_operator = true;
-  std::string first_operator_cnode;
-  size_t current_op_index = 0;
+  // The map from CNode's UniqueId to its operatorInfo
+  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
   // Step 1
   for (auto &node : all_nodes) {
@@ -449,12 +449,8 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
     MS_EXCEPTION_IF_NULL(prim);
-    // When visiting the second subgraph, use the corresponding operatorInfo which already created
-    bool modify_new_operator = (new_operator) && (!first_operator) && (cnode->UniqueId() == first_operator_cnode);
-    if (modify_new_operator) {
-      new_operator = false;
-    }
-    if (new_operator) {
+    auto search_cnode = from_cnode_to_info.find(cnode->UniqueId());
+    if (search_cnode == from_cnode_to_info.end()) {
       auto operator_info = CreateTheOperatorInfo(prim, cnode);
       if (operator_info == nullptr) {
         return FAILED;
@@ -465,14 +461,67 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
       entire_costgraph->AddOperator(operator_info);
       (void)cnode->set_operator_info(operator_info);
-      if (first_operator) {
-        first_operator_cnode = cnode->UniqueId();
-        first_operator = false;
+      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
+      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueId(), operator_info));
+      // Needed by rec_parser
+      entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
+    } else {
+      // Two CNODEs' UniqueIds should not be equal
+      MS_LOG(EXCEPTION) << "The CNode with UniqueId: " << cnode->UniqueId()
+                        << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                        << " is set OperatorInfo: " << search_cnode->second->name() << ", Primitive: " << prim->name();
+    }
+  }
+  MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
+  return SUCCESS;
+}
+
+// Using CNode's UniqueIdThroughCopys to construct nodes
+Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
+  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
+  entire_costgraph = std::make_shared<CostGraph>();
+  entire_costgraph->SetDeviceMemoryAndCostParameter();
+  // The map from CNode's UniqueIdThroughCopy to its operatorInfo
+  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
+  for (auto &node : all_nodes) {
+    // NOTE: we only care about splittable Primitive operators
+    auto cnode = node->cast<CNodePtr>();
+    bool bool_result = (cnode == nullptr) || (!IsValueNode<Primitive>(cnode->input(0)));
+    if (bool_result) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+    if (!IsAutoParallelCareNode(cnode)) {
+      continue;
+    }
+    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    // Find the operatorInfo if it exists
+    auto search_cnode = from_cnode_to_info.find(cnode->UniqueIdThroughCopy());
+    if (search_cnode == from_cnode_to_info.end()) {
+      // In this case, the corresponding OperatorInfo is not created, create the new one.
+      auto operator_info = CreateTheOperatorInfo(prim, cnode);
+      if (operator_info == nullptr) {
+        return FAILED;
       }
       // Needed by rec_parser
+      operator_info->set_type(prim->name());
+      std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
+      entire_costgraph->AddOperator(operator_info);
+      (void)cnode->set_operator_info(operator_info);
+      MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                   << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                   << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
+      (void)from_cnode_to_info.emplace(std::make_pair(cnode->UniqueIdThroughCopy(), operator_info));
+      // Needed by rec_parser
       entire_costgraph->add_inputs_tensor_name(inputs_tensor_name);
     } else {
-      auto current_op_ptr = entire_costgraph->FindOperatorByIndex(current_op_index);
+      auto current_op_ptr = search_cnode->second;
       if (current_op_ptr == nullptr) {
         MS_LOG(EXCEPTION) << "Find " << prim->name() << " from CostGraph failed.";
       } else {
@@ -484,14 +533,12 @@ Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const F
                             << " does not match the Prim: " << prim->name();
         }
         (void)cnode->set_operator_info(current_op_ptr);
-        current_op_index++;
+        MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
+                     << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
+                     << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
       }
     }
   }
-  if ((!new_operator) && (current_op_index != entire_costgraph->GetOperators().size())) {
-    MS_LOG(EXCEPTION) << "The second subgraph's operator number: " << current_op_index
-                      << " does not match the first ones: " << entire_costgraph->GetOperators().size();
-  }
   MS_LOG(INFO) << "Constructing nodes for cost graph ends.";
   return SUCCESS;
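
The two builders above differ only in the key used to deduplicate OperatorInfo objects: ConstructCostGraphNodesByUniqueId keys its cache by a CNode's UniqueId, so every distinct CNode gets its own OperatorInfo, while ConstructCostGraphNodesByUniqueIdTC keys it by UniqueIdThroughCopy, so CNodes that are copies of the same original node (as happens when a func graph is cloned into several subgraphs) share one OperatorInfo. A minimal plain-Python sketch of that caching rule, with hypothetical names, for illustration only:

    def build_operator_infos(cnodes, key_fn, create_info):
        """Create one OperatorInfo per distinct key; later nodes with the same key reuse it."""
        cache = {}   # key -> OperatorInfo
        result = []
        for cnode in cnodes:
            key = key_fn(cnode)  # e.g. a UniqueId or a UniqueIdThroughCopy string
            if key not in cache:
                cache[key] = create_info(cnode)  # first occurrence: create a new OperatorInfo
            result.append((cnode, cache[key]))   # copies of the same node share one OperatorInfo
        return result
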
@@ -844,11 +891,20 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   // OUTPUT: the determined strategy for each operator.
   // Step 1
-  if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
-    MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
-                 << " operators.";
+  if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
+    if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
   } else {
-    MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
+      MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are "
+                   << entire_costgraph->GetOperators().size() << " operators.";
+    } else {
+      MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
+    }
   }
   // Step 2
@@ -916,7 +972,7 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
 }
 Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
-  if (ConstructCostGraphNodes(all_nodes, root) == SUCCESS) {
+  if (ConstructCostGraphNodesByUniqueId(all_nodes, root) == SUCCESS) {
     MS_LOG(INFO) << "Constructing nodes for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()
                  << " operators.";
   } else {
@@ -43,7 +43,9 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
 std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
-Status ConstructCostGraphNodes(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
+Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
+Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
 void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes);
@@ -24,6 +24,7 @@
 #include <functional>
 #include "ir/func_graph_cloner.h"
+#include "parallel/costmodel_context.h"
 #include "pipeline/pass.h"
 #include "pipeline/parse/parse_base.h"
 #include "pipeline/parse/data_converter.h"
@@ -341,7 +342,10 @@ static std::vector<ActionItem> CommonPipeline() {
   // Resolve the python func
   actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
-  actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
+  auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
+  if (!multi_graphs) {
+    actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
+  }
   actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
   // Evaluate type and shape, and specialize
   actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
@@ -222,6 +222,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set the parameter cost_model_communi_bias of the DP algorithm.")
     .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias,
          "Get the parameter cost_model_communi_bias of the DP algorithm.")
+    .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
+    .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
     .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
          "Set the parameter gradient AllReduce fusion algorithm.")
     .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
@@ -214,6 +214,31 @@ class _CostModelContext:
             raise ValueError("Context handle is none in context!!!")
         return self._context_handle.get_costmodel_communi_bias()
 
+    def set_multi_subgraphs(self, multi_subgraph):
+        """
+        Set the flag of ANF graph containing multiple subgraphs.
+
+        Args:
+            multi_subgraph (bool): A parameter used in marking the multi-subgraphs flag.
+
+        Raises:
+            ValueError: If context handle is none.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        self._context_handle.set_multi_subgraphs(multi_subgraph)
+
+    def get_multi_subgraphs(self):
+        """
+        Get the flag of ANF graph containing multiple subgraphs.
+
+        Raises:
+            ValueError: If context handle is none.
+        """
+        if self._context_handle is None:
+            raise ValueError("Context handle is none in context!!!")
+        return self._context_handle.get_multi_subgraphs()
+
     def set_costmodel_allreduce_fusion_algorithm(self, algorithm):
         """
         Set costmodel allreduce fusion algorithm.
@@ -427,6 +452,7 @@ set_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().set_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().set_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().set_costmodel_communi_bias,
+    "multi_subgraphs": cost_model_context().set_multi_subgraphs,
     "costmodel_allreduce_fusion_algorithm": cost_model_context().set_costmodel_allreduce_fusion_algorithm,
     "costmodel_allreduce_fusion_times": cost_model_context().set_costmodel_allreduce_fusion_times,
     "costmodel_allreduce_fusion_tail_percent": cost_model_context().set_costmodel_allreduce_fusion_tail_percent,
@@ -447,6 +473,7 @@ get_cost_model_context_func_map = {
     "costmodel_communi_threshold": cost_model_context().get_costmodel_communi_threshold,
     "costmodel_communi_const": cost_model_context().get_costmodel_communi_const,
     "costmodel_communi_bias": cost_model_context().get_costmodel_communi_bias,
| "multi_subgraphs": cost_model_context().get_multi_subgraphs(), | |||||
| "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm, | "costmodel_allreduce_fusion_algorithm": cost_model_context().get_costmodel_allreduce_fusion_algorithm, | ||||
| "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times, | "costmodel_allreduce_fusion_times": cost_model_context().get_costmodel_allreduce_fusion_times, | ||||
| "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent, | "costmodel_allreduce_fusion_tail_percent": cost_model_context().get_costmodel_allreduce_fusion_tail_percent, | ||||
@@ -461,6 +488,7 @@ get_cost_model_context_func_map = {
 @args_type_check(device_memory_capacity=float, costmodel_alpha=float, costmodel_beta=float, costmodel_gamma=float,
                  costmodel_communi_threshold=float, costmodel_communi_const=float, costmodel_communi_bias=float,
+                 multi_subgraphs=bool,
                  costmodel_allreduce_fusion_algorithm=int, costmodel_allreduce_fusion_times=int,
                  costmodel_allreduce_fusion_tail_percent=float, costmodel_allreduce_fusion_tail_time=float,
                  costmodel_allreduce_fusion_allreduce_inherent_time=float,
@@ -481,6 +509,7 @@ def set_cost_model_context(**kwargs):
         costmodel_communi_threshold (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_const (float): A parameter used in adjusting communication calculation for practice.
         costmodel_communi_bias (float): A parameter used in adjusting communication calculation for practice.
+        multi_subgraphs (bool): A parameter used in marking the flag of ANF graph containing multiple subgraphs.
         costmodel_allreduce_fusion_algorithm (int): The allreduce fusion algorithm.
             0: bypass allreduce fusion;
             1: only use backward computation time to group allreduce;
@@ -0,0 +1,101 @@
+import numpy as np
+from mindspore import context
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.nn.optim import Adam, FTRL
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore import Tensor, Parameter, ParameterTuple
+from mindspore.ops import composite as C
+from mindspore.parallel import _cost_model_context as cost_model_context
+from mindspore.common.api import _executor
+from mindspore.parallel import set_algo_parameters, get_algo_parameters, reset_algo_parameters
+from mindspore.parallel._utils import _reset_op_id as reset_op_id
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.mul = P.Mul()
+        self.relu = P.ReLU()
+        self.wd = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide")
+        self.wt = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="l")
+
+    def construct(self, x):
+        out = self.mul(x, self.wd)
+        out = self.mul(out, self.wt)
+        out = self.relu(out)
+        return out
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.sum = P.ReduceSum()
+        self.mean = P.ReduceMean()
+        self.net = network
+
+    def construct(self, x):
+        predict = self.net(x)
+        loss1 = self.sum(predict, -1)
+        loss2 = self.mean(predict, -1)
+        return loss1, loss2
+
+
+class IthOutputCell(nn.Cell):
+    def __init__(self, network, output_index):
+        super(IthOutputCell, self).__init__()
+        self.network = network
+        self.output_index = output_index
+
+    def construct(self, x):
+        predict = self.network(x)[self.output_index]
+        return predict
+
+
+class TrainStepWarp(nn.Cell):
+    def __init__(self, network, sens=1000.0):
+        super(TrainStepWarp, self).__init__()
+        self.network = network
+        self.network.set_train()
+        self.trainable_params = network.trainable_params()
+        weights_w = []
+        weights_d = []
+        for params in self.trainable_params:
+            weights_w.append(params)
+            weights_d.append(params)
+        self.weights_w = ParameterTuple(weights_w)
+        self.weights_d = ParameterTuple(weights_d)
+        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, l1=1e-8,
+                                l2=1e-8, initial_accum=1.0)
+        self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8,
+                                loss_scale=sens)
+        self.hyper_map = C.HyperMap()
+        self.grad_w = C.GradOperation('grad_w', get_by_list=True, sens_param=True)
+        self.grad_d = C.GradOperation('grad_d', get_by_list=True, sens_param=True)
+        self.sens = sens
+        self.loss_net_w = IthOutputCell(network, output_index=0)
+        self.loss_net_d = IthOutputCell(network, output_index=1)
+
+    def construct(self, x):
+        weights_w = self.weights_w
+        weights_d = self.weights_d
+        loss_w, loss_d = self.network(x)
+        sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens)
+        sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens)
+        grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w)
+        grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d)
+        return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d))
+
+
+def test_double_subgraphs():
+    cost_model_context.set_cost_model_context(multi_subgraphs=True)
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    net = TrainStepWarp(NetWithLoss(Net()))
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32)
+    reset_op_id()
+
+    _executor.compile(net, x, phase='train')
+    strategies = _executor._get_strategy(net)
+    expected_strategies = {'Default/network-NetWithLoss/ReduceMean-op0': [[8, 1, 1, 1]],
+                           'Default/network-NetWithLoss/net-Net/ReLU-op1': [[8, 1, 1, 1]],
+                           'Default/network-NetWithLoss/net-Net/Mul-op2': [[8, 1, 1, 1], [8, 1, 1, 1]],
+                           'Default/network-NetWithLoss/net-Net/Mul-op3': [[8, 1, 1, 1], [8, 1, 1, 1]],
+                           'Default/network-NetWithLoss/ReduceSum-op4': [[8, 1, 1, 1]]}
+    assert strategies == expected_strategies
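
For readers new to the strategy dumps checked above: each inner list gives, for one input of the operator, how many slices every tensor dimension is cut into across devices. A small worked sketch of what [[8, 1, 1, 1]] means for the (8, 8, 8, 8) inputs used in this test (plain Python, illustration only):

    full_shape = (8, 8, 8, 8)
    strategy = (8, 1, 1, 1)  # number of slices per dimension, spread over the 8 devices
    slice_shape = [dim // cuts for dim, cuts in zip(full_shape, strategy)]
    assert slice_shape == [1, 8, 8, 8]  # each device holds one (1, 8, 8, 8) slice
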
@@ -0,0 +1,70 @@
+import numpy as np
+from mindspore import context
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.ops import operations as P
+from mindspore import Tensor
+from mindspore.common.api import _executor
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+from mindspore.parallel import set_algo_parameters
+from mindspore.parallel._utils import _reset_op_id as reset_op_id
+import re
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x):
+        predict = self.network(x)
+        return self.loss(predict)
+
+
+class Blockcell(nn.Cell):
+    def __init__(self):
+        super(Blockcell, self).__init__()
+        self.bn = nn.BatchNorm2d(64, momentum=0.9)
+
+    def construct(self, x):
+        out = self.bn(x)
+        return out
+
+
+def getBlock():
+    return Blockcell()
+
+
+def test_two_bn():
+    class Net(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.block1 = getBlock()
+            self.block2 = getBlock()
+            self.relu = P.ReLU()
+            self.add = P.TensorAdd()
+            self.bias = Tensor(np.ones([64, 64]), dtype=ms.float32)
+
+        def construct(self, x):
+            out = self.block1(x)
+            out = self.relu(out)
+            out = self.add(out, self.bias)
+            out = self.block2(out)
+            return out
+
+    net = NetWithLoss(Net())
+    x = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(device_num=8, global_rank=0)
+    context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    set_algo_parameters(elementwise_op_strategy_follow=True)
+    reset_op_id()
+
+    _executor.compile(net, x, phase='train')
+    strategies = _executor._get_strategy(net)
+    assert len(strategies) == 4
+    for (k, v) in strategies.items():
+        if re.search('BatchNorm-op', k) is not None:
+            assert v == [[8, 1], [1], [1], [1], [1]]
+        elif re.search('TensorAdd-op', k) is not None:
+            assert v == [[8, 1], [8, 1]]
+        elif re.search('ReLU-op', k) is not None:
+            assert v == [[8, 1]]