diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
index 272260b20d..cc9b598c3e 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc
@@ -18,10 +18,11 @@
 #include <string>
 
-#include "ir/anf.h"
 #include "ir/param_info.h"
 #include "ir/meta_tensor.h"
 #include "pipeline/jit/parse/python_adapter.h"
+#include "frontend/parallel/ops_info/ops_utils.h"
+#include "frontend/parallel/step_parallel.h"
 
 namespace mindspore {
 namespace parallel {
@@ -45,5 +46,331 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
   }
   return param_value->requires_grad();
 }
+
+// Given the node, return whether each input is a parameter or an output of an operator.
+// The returned boolean vector has the same order as the inputs, so this implementation
+// is kept closely consistent with ExtractShape() in step_parallel.cc.
+std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
+  std::vector<bool> is_parameter;
+  std::vector<AnfNodePtr> node_inputs{node->inputs()};
+  // If the input is a ValueList or ValueTuple, then none of the inputs is a parameter.
+  if ((node_inputs.size() == 2) &&
+      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
+    std::vector<ValuePtr> inputs_seq;
+    if (IsValueNode<ValueList>(node_inputs[1])) {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
+    } else {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
+    }
+    return std::vector<bool>(inputs_seq.size(), false);
+  }
+  if ((node_inputs.size() == 2) &&
+      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
+    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
+  }
+  for (size_t i = 1; i < node_inputs.size(); ++i) {
+    auto input = node_inputs[i];
+
+    if (input->isa<Parameter>()) {
+      auto input_parameter = input->cast<ParameterPtr>();
+      is_parameter.push_back(ParameterRequireGrad(input_parameter));
+    } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
+      is_parameter.push_back(false);
+    }
+  }
+  return is_parameter;
+}
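+
+// Illustrative example (hypothetical graph, not part of this patch): for a
+// MatMul(x, w) cnode whose input x is another operator's output and whose input
+// w is a Parameter with requires_grad=true, the returned vector is {false, true}.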
+
+// Given the type, return the number of bytes to represent this type
+size_t GetLengthOfDataType(const TypePtr &type) {
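+  // For example, Float16 yields 2 and Float32 yields 4; every size below is
+  // derived from the host compiler's sizeof for the corresponding C++ type.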
+  switch (type->type_id()) {
+    case kNumberTypeBool:
+      return sizeof(bool);
+    case kNumberTypeInt8:
+      return sizeof(int8_t);
+    case kNumberTypeInt16:
+      return sizeof(int16_t);
+    case kNumberTypeInt32:
+      return sizeof(int32_t);
+    case kNumberTypeInt64:
+      return sizeof(int64_t);
+    case kNumberTypeUInt8:
+      return sizeof(uint8_t);
+    case kNumberTypeUInt16:
+      return sizeof(uint16_t);
+    case kNumberTypeUInt32:
+      return sizeof(uint32_t);
+    case kNumberTypeUInt64:
+      return sizeof(uint64_t);
+    case kNumberTypeFloat16:
+      return sizeof(float) / 2;
+    case kNumberTypeFloat32:
+      return sizeof(float);
+    case kNumberTypeFloat64:
+      return sizeof(double);
+    case kNumberTypeInt:
+      return sizeof(int64_t);
+    case kNumberTypeUInt:
+      return sizeof(uint64_t);
+    case kNumberTypeFloat:
+      return sizeof(float);
+    default:
+      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
+  }
+}
+
+size_t GetInputsTypeLen(const AnfNodePtr &input) {
+  MS_EXCEPTION_IF_NULL(input);
+  if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
+    MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
+  }
+
+  size_t input_type_len = 0;
+  auto type = input->Type();
+  MS_EXCEPTION_IF_NULL(type);
+  if (type->isa<TensorType>()) {
+    auto input_element_type = type->cast<TensorTypePtr>()->element();
+    input_type_len = GetLengthOfDataType(input_element_type);
+  } else {
+    MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
+  }
+  return input_type_len;
+}
+
+std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<size_t> inputs_type_len;
+  std::vector<AnfNodePtr> node_inputs{node->inputs()};
+
+  if ((node_inputs.size() == 2) &&
+      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
+    std::vector<ValuePtr> inputs_seq;
+    if (IsValueNode<ValueList>(node_inputs[1])) {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
+    } else {
+      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
+    }
+    for (auto &ele : inputs_seq) {
+      auto tensor = ele->cast<tensor::TensorPtr>();
+      MS_EXCEPTION_IF_NULL(tensor);
+      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
+    }
+    return inputs_type_len;
+  }
+
+  if ((node_inputs.size() == 2) &&
+      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
+    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
+  }
+
+  // extract input element length
+  for (auto &input : node_inputs) {
+    if (IsValueNode<RefKey>(input)) {
+      auto func_graph = node->func_graph();
+      MS_EXCEPTION_IF_NULL(func_graph);
+      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
+      if (parameters.size() != 1) {
+        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
+      }
+      inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
+    } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
+      // extract input shape from parameter and apply node
+      inputs_type_len.push_back(GetInputsTypeLen(input));
+    }
+  }
+  return inputs_type_len;
+}
+
+std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  std::vector<TypePtr> outputs_type;
+  // extract output element type
+  auto primary_output_type = node->Type();
+  MS_EXCEPTION_IF_NULL(primary_output_type);
+  if (primary_output_type->isa<Tuple>()) {
+    // in this case, the output is a tuple
+    auto tuple_output_type = primary_output_type->cast<TuplePtr>();
+    auto elements = tuple_output_type->elements();
+    for (auto &ele : elements) {
+      if (ele->isa<TensorType>()) {
+        auto ele_element_type = ele->cast<TensorTypePtr>()->element();
+        outputs_type.push_back(ele_element_type);
+      } else {
+        MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
+      }
+    }
+  } else {
+    // in this case, the output is a single tensor
+    if (primary_output_type->isa<TensorType>()) {
+      auto element_type = primary_output_type->cast<TensorTypePtr>()->element();
+      outputs_type.push_back(element_type);
+    } else {
+      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
+    }
+  }
+  return outputs_type;
+}
+
+std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> parameters;
+  if (!IsValueNode<RefKey>(node)) {
+    MS_LOG(ERROR) << "The node is not a ref key";
+    return parameters;
+  }
+
+  auto ref_key = GetValueNode<RefKeyPtr>(node);
+  MS_EXCEPTION_IF_NULL(ref_key);
+  auto name = ref_key->tag();
+
+  auto manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto roots = manager->roots();
+  if (roots.size() != 1) {
+    MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
+    return parameters;
+  }
+
+  FuncGraphPtr root_g = roots.back();
+  MS_EXCEPTION_IF_NULL(root_g);
+  for (auto &param_node : root_g->parameters()) {
+    auto param = param_node->cast<ParameterPtr>();
+    if (param && (name == param->name())) {
+      parameters.push_back(param_node);
+      MS_LOG(INFO) << "The name of ref key is: " << name;
+      return parameters;
+    }
+  }
+
+  MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter";
+  return parameters;
+}
+
+bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto cnode = anf_node->cast<CNodePtr>();
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+
+  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(value_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  if (prim->name() == prim_name) {
+    return true;
+  }
+  return false;
+}
+
+bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
+    return false;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  if (prim->name() == RESHAPE) {
+    auto operator_info = cnode->user_data<OperatorInfo>();
+    std::string op_info_name = operator_info->name();
+    if (op_cache->find(op_info_name) != op_cache->end()) {
+      return false;
+    }
+    op_cache->insert(op_info_name);
+    return true;
+  }
+  return false;
+}
+
+// Find the previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
+bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) {
+  // if the previous node is a parameter, it is handled outside of this function.
+  if (node->isa<Parameter>()) {
+    return false;
+  }
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  CNodePtr cnode = node->cast<CNodePtr>();
+  if (!IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  auto node_op_info = cnode->user_data<OperatorInfo>();
+  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
+    *pre_operator_info = node_op_info;
+    *out_index = 0;
+    return true;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+  if (prim->name() == TUPLE_GETITEM) {
+    *out_index = GetTupleGetItemIndex(cnode);
+    // find tuple_get_item's previous node
+    auto pre_node = cnode->input(1);
+    if (!pre_node->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "tuple_get_item's second input is not a cnode";
+    }
+    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
+    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
+    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
+      *pre_operator_info = pre_op_info;
+      return true;
+    }
+    return false;
+  }
+  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+    if (prim->name() == DEPEND && index != 1) {
+      continue;
+    }
+    if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
+      continue;
+    }
+    return true;
+  }
+  MS_LOG(WARNING)
+    << "FindReshapePreNodeStraCosts failed; if Reshape is not the first primitive, there must be some error";
+  return false;
+}
+
+// Find the next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
+// If Reshape's output connects to several primitives, return the first layout found.
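+// Depend users are followed only through their data edge (user input index 1);
+// the other operands of Depend are control dependencies and are skipped.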
+bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(cnode->func_graph());
+  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[cnode];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    MS_LOG(INFO) << "FindReshapeNextNodeStraCosts prim " << node_prim->name();
+    if (node_prim->name() == DEPEND && node_pair.second != 1) {
+      continue;
+    }
+    auto op_info = use_apply->user_data<OperatorInfo>();
+    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
+      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
+      *next_operator_info = op_info;
+      *in_index = node_pair.second - 1;
+      return true;
+    }
+    MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " "
+                  << IsParallelCareNode(use_apply) << " " << (op_info != nullptr);
+
+    if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
+      return true;
+    }
+  }
+  return false;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h
index ea97747e1d..105893deb3 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.h
@@ -18,13 +18,37 @@
 #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_
 
 #include <string>
+#include <memory>
+#include <unordered_set>
+#include <vector>
 
 #include "base/base.h"
+#include "ir/anf.h"
+#include "frontend/parallel/ops_info/operator_info.h"
 
 namespace mindspore {
 namespace parallel {
+using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;
 std::string ParameterName(const AnfNodePtr &node_ptr);
 
 bool ParameterRequireGrad(const AnfNodePtr &node_ptr);
+
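+// The declarations below correspond to the helpers moved into this file from
+// step_auto_parallel.cc and step_parallel.cc.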
+size_t GetLengthOfDataType(const TypePtr &type);
+
+std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node);
+
+std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
+
+std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
+
+std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
+
+bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
+
+bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache);
+
+bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index);
+
+bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index c21bdac67f..cfd6f396e1 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -63,7 +63,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   // check whether strategy_search_mode is valid
   std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
   if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
-    // Setting searching mode: dynanic programming as default.
+    // Setting searching mode: dynamic programming as default.
     strategy_search_mode = DYNAMIC_PROGRAMMING;
     MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
   }
@@ -112,170 +112,6 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   return changes;
 }
 
-// Given the node, return whether each input is a parameter or a output of a operator.
-// The returned boolean vector should be the same order of the inputs, thus its implementation
-// is closely consistent with ExtractShape() in step_parallel.cc
-std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
-  std::vector<bool> is_parameter;
-  std::vector<AnfNodePtr> node_inputs{node->inputs()};
-  // input is a ValueList or ValueTuple, then all inputs are not parameter.
-  if ((node_inputs.size() == 2) &&
-      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
-    std::vector<ValuePtr> inputs_seq;
-    if (IsValueNode<ValueList>(node_inputs[1])) {
-      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
-    } else {
-      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
-    }
-    return std::vector<bool>(inputs_seq.size(), false);
-  }
-  if ((node_inputs.size() == 2) &&
-      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
-    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
-  }
-  for (size_t i = 1; i < node_inputs.size(); ++i) {
-    auto input = node_inputs[i];
-
-    if (input->isa<Parameter>()) {
-      auto input_parameter = input->cast<ParameterPtr>();
-      is_parameter.push_back(ParameterRequireGrad(input_parameter));
-    } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
-      is_parameter.push_back(false);
-    }
-  }
-  return is_parameter;
-}
-
-// Given the type, return the number of bytes to represent this type
-size_t GetLengthOfDataType(const TypePtr &type) {
-  switch (type->type_id()) {
-    case kNumberTypeBool:
-      return sizeof(bool);
-    case kNumberTypeInt8:
-      return sizeof(int8_t);
-    case kNumberTypeInt16:
-      return sizeof(int16_t);
-    case kNumberTypeInt32:
-      return sizeof(int32_t);
-    case kNumberTypeInt64:
-      return sizeof(int64_t);
-    case kNumberTypeUInt8:
-      return sizeof(uint8_t);
-    case kNumberTypeUInt16:
-      return sizeof(uint16_t);
-    case kNumberTypeUInt32:
-      return sizeof(uint32_t);
-    case kNumberTypeUInt64:
-      return sizeof(uint64_t);
-    case kNumberTypeFloat16:
-      return sizeof(float) / 2;
-    case kNumberTypeFloat32:
-      return sizeof(float);
-    case kNumberTypeFloat64:
-      return sizeof(double);
-    case kNumberTypeInt:
-      return sizeof(int64_t);
-    case kNumberTypeUInt:
-      return sizeof(uint64_t);
-    case kNumberTypeFloat:
-      return sizeof(float);
-    default:
-      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
-  }
-}
-
-size_t GetInputsTypeLen(const AnfNodePtr &input) {
-  MS_EXCEPTION_IF_NULL(input);
-  if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
-    MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor";
-  }
-
-  size_t input_type_len = 0;
-  auto type = input->Type();
-  MS_EXCEPTION_IF_NULL(type);
-  if (type->isa<TensorType>()) {
-    auto input_element_type = type->cast<TensorTypePtr>()->element();
-    input_type_len = GetLengthOfDataType(input_element_type);
-  } else {
-    MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
-  }
-  return input_type_len;
-}
-
-std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  std::vector<size_t> inputs_type_len;
-  std::vector<AnfNodePtr> node_inputs{node->inputs()};
-
-  if ((node_inputs.size() == 2) &&
-      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
-    std::vector<ValuePtr> inputs_seq;
-    if (IsValueNode<ValueList>(node_inputs[1])) {
-      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
-    } else {
-      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
-    }
-    for (auto &ele : inputs_seq) {
-      auto tensor = ele->cast<tensor::TensorPtr>();
-      MS_EXCEPTION_IF_NULL(tensor);
-      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
-    }
-    return inputs_type_len;
-  }
-
-  if ((node_inputs.size() == 2) &&
-      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
-    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
-  }
-
-  // extract input element length
-  for (auto &input : node_inputs) {
-    if (IsValueNode<RefKey>(input)) {
-      auto func_graph = node->func_graph();
-      MS_EXCEPTION_IF_NULL(func_graph);
-      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
-      if (parameters.size() != 1) {
-        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
-      }
-      inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
-    } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
-      // extract input shape from parameter and apply node
-      inputs_type_len.push_back(GetInputsTypeLen(input));
-    }
-  }
-  return inputs_type_len;
-}
-
-std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
-  MS_EXCEPTION_IF_NULL(node);
-  std::vector<TypePtr> outputs_type;
-  // extract output element type
-  auto primary_output_type = node->Type();
-  MS_EXCEPTION_IF_NULL(primary_output_type);
-  if (primary_output_type->isa<Tuple>()) {
-    // in this case, the output is a tuple
-    auto tuple_output_type = primary_output_type->cast<TuplePtr>();
-    auto elements = tuple_output_type->elements();
-    for (auto &ele : elements) {
-      if (ele->isa<TensorType>()) {
-        auto ele_element_type = ele->cast<TensorTypePtr>()->element();
-        outputs_type.push_back(ele_element_type);
-      } else {
-        MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
-      }
-    }
-  } else {
-    // in this case, the output is a single tensor
-    if (primary_output_type->isa<TensorType>()) {
-      auto element_type = primary_output_type->cast<TensorTypePtr>()->element();
-      outputs_type.push_back(element_type);
-    } else {
-      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
-    }
-  }
-  return outputs_type;
-}
-
 bool IsElementWiseOperator(const std::string &op_name) {
   // clang-format off
   static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
@@ -381,6 +217,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cnode) {
   return true;
 }
 
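+// Create a fresh global cost graph and load the device-memory and cost-model
+// parameters once, before cost-graph nodes are constructed.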
+void InitCostGraph() {
+  entire_costgraph = std::make_shared<CostGraph>();
+  entire_costgraph->SetDeviceMemoryAndCostParameter();
+}
+
 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
   MS_EXCEPTION_IF_NULL(prim);
   MS_EXCEPTION_IF_NULL(cnode);
@@ -491,8 +332,6 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
 // Using CNode's UniqueIds to construct nodes
 Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
-  entire_costgraph = std::make_shared<CostGraph>();
-  entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueId to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
   // The operator_infos in a loop
@@ -506,7 +345,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
       MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
     }
   }
-  // Step 1
+
   for (auto &node : all_nodes) {
     // NOTE: we only care about splittable Primitive operators
     auto cnode = node->cast<CNodePtr>();
@@ -588,8 +427,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
 // Using CNode's UniqueIdThroughCopys to construct nodes
 Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
   MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
-  entire_costgraph = std::make_shared<CostGraph>();
-  entire_costgraph->SetDeviceMemoryAndCostParameter();
   // The map from CNode's UniqueIdThroughCopy to its operatorInfo
   std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
   // The operator_infos in a loop
@@ -937,115 +774,6 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
-  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
-    return false;
-  }
-  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
-    return false;
-  }
-  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  if (prim->name() == RESHAPE) {
-    auto operator_info = cnode->user_data<OperatorInfo>();
-    std::string op_info_name = operator_info->name();
-    if (op_cache->find(op_info_name) != op_cache->end()) {
-      return false;
-    }
-    op_cache->insert(op_info_name);
-    return true;
-  }
-  return false;
-}
-
-// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
-bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) {
-  // if previous node is a parameter, handle it in the outsize.
-  if (node->isa<Parameter>()) {
-    return false;
-  }
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-  CNodePtr cnode = node->cast<CNodePtr>();
-  if (!IsValueNode<Primitive>(cnode->input(0))) {
-    return false;
-  }
-  auto node_op_info = cnode->user_data<OperatorInfo>();
-  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
-    *pre_operator_info = node_op_info;
-    *out_index = 0;
-    return true;
-  }
-  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
-  if (prim->name() == TUPLE_GETITEM) {
-    *out_index = GetTupleGetItemIndex(cnode);
-    // find tuple_get_item's previous node
-    auto pre_node = cnode->input(1);
-    if (!pre_node->isa<CNode>()) {
-      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
-    }
-    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
-    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
-    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
-      *pre_operator_info = pre_op_info;
-      return true;
-    }
-    return false;
-  }
-  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
-    if (prim->name() == DEPEND && index != 1) {
-      continue;
-    }
-    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
-      continue;
-    }
-    return true;
-  }
-  MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error";
-  return false;
-}
-
-// find next node, then obtain its strategy_cost_ vector to get its layout vector.
-// if reshape's output connect to several primitive, return the first layout found
-bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) {
-  MS_EXCEPTION_IF_NULL(cnode);
-  MS_EXCEPTION_IF_NULL(cnode->func_graph());
-  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  AnfNodeIndexSet node_set = manager->node_users()[cnode];
-  for (auto &node_pair : node_set) {
-    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
-    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
-      continue;
-    }
-    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(prim_anf_node);
-    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(node_prim);
-    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
-    if (node_prim->name() == DEPEND && node_pair.second != 1) {
-      continue;
-    }
-    auto op_info = use_apply->user_data<OperatorInfo>();
-    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
-      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
-      *next_operator_info = op_info;
-      *in_index = node_pair.second - 1;
-      return true;
-    }
-    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
-                  << " " << (op_info != nullptr);
-
-    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
-      return true;
-    }
-  }
-  return false;
-}
-
 void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
   std::unordered_set<std::string> op_cache;
   for (auto node : all_nodes) {
@@ -1066,8 +794,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
       pre_operator_info = reshape_info;
       pre_stra_costs = reshape_info->strategy_cost();
     } else {
-      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
-        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
+      if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
+        MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
       }
       pre_stra_costs = pre_operator_info->strategy_cost();
     }
@@ -1075,9 +803,9 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
     int64_t in_index = 0;
     OperatorInfoPtr next_operator_info;
     std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
-    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
+    bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index);
     if (!find_next_node) {
-      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
+      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
    }
     // set input_layout and output_layout for reshape.
     // init reshape and set cost for each input_layout and output_layout.
@@ -1122,6 +850,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
   //
   // OUTPUT: the determined strategy for each operator.
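+  // The cost graph is now initialized once here, instead of inside each
+  // ConstructCostGraphNodesByUniqueId*() variant.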
+  InitCostGraph();
   // Step 1
   if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h
index 13d96ce334..17af1b96a8 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.h
@@ -28,20 +28,14 @@
 namespace mindspore {
 namespace parallel {
-bool IsSplittableOperator(const std::string &);
-
-bool IsAutoParallelCareNode(const CNodePtr &);
-
 // main step of Auto-parallel
 bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
 
-size_t GetLengthOfDataType(const TypePtr &type);
-
-std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node);
+bool IsSplittableOperator(const std::string &);
 
-std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
+bool IsAutoParallelCareNode(const CNodePtr &);
 
-std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
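+// Creates the global cost graph and sets its device-memory and cost-model
+// parameters; call it before the ConstructCostGraphNodes* functions.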
+void InitCostGraph();
 
 Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index a26234b7ae..7fe85d60e0 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -292,22 +292,6 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &prim) {
   return tensorinfo_in.tensor_layout();
 }
 
-bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
-  MS_EXCEPTION_IF_NULL(anf_node);
-  auto cnode = anf_node->cast<CNodePtr>();
-  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
-    return false;
-  }
-
-  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
-  auto prim = GetValueNode<PrimitivePtr>(value_node);
-  MS_EXCEPTION_IF_NULL(prim);
-  if (prim->name() == prim_name) {
-    return true;
-  }
-  return false;
-}
-
 std::string GetPrimName(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (!IsValueNode<Primitive>(node->input(0))) {
@@ -1219,42 +1203,6 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   return shapes;
 }
 
-std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
-  MS_EXCEPTION_IF_NULL(node);
-  MS_EXCEPTION_IF_NULL(func_graph);
-  std::vector<AnfNodePtr> parameters;
-  if (!IsValueNode<RefKey>(node)) {
-    MS_LOG(ERROR) << "The node is not a ref key";
-    return parameters;
-  }
-
-  auto ref_key = GetValueNode<RefKeyPtr>(node);
-  MS_EXCEPTION_IF_NULL(ref_key);
-  auto name = ref_key->tag();
-
-  auto manager = func_graph->manager();
-  MS_EXCEPTION_IF_NULL(manager);
-  auto roots = manager->roots();
-  if (roots.size() != 1) {
-    MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1";
-    return parameters;
-  }
-
-  FuncGraphPtr root_g = roots.back();
-  MS_EXCEPTION_IF_NULL(root_g);
-  for (auto &param_node : root_g->parameters()) {
-    auto param = param_node->cast<ParameterPtr>();
-    if (param && (name == param->name())) {
-      parameters.push_back(param_node);
-      MS_LOG(INFO) << "The name of ref key is: " << name;
-      return parameters;
-    }
-  }
-
-  MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter";
-  return parameters;
-}
-
 Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(func_graph);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index ab4ecdf101..a9c6a436d7 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -100,8 +100,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs);
 
 Shapes GetNodeShape(const AnfNodePtr &node);
 
-std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
-
 // Extract shape from anfnode
 std::vector<Shapes> ExtractShape(const CNodePtr &node);
 
@@ -154,8 +152,6 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
 
 std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
 
-bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
-
 using RefKeyPair = std::pair<CNodePtr, std::vector<AnfNodePtr>>;
 
 using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;