From: @xiaoda_zh Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsutengtags/v1.1.0
| @@ -18,10 +18,11 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "ir/anf.h" | |||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "ir/meta_tensor.h" | #include "ir/meta_tensor.h" | ||||
| #include "pipeline/jit/parse/python_adapter.h" | #include "pipeline/jit/parse/python_adapter.h" | ||||
| #include "frontend/parallel/ops_info/ops_utils.h" | |||||
| #include "frontend/parallel/step_parallel.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -45,5 +46,331 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { | |||||
| } | } | ||||
| return param_value->requires_grad(); | return param_value->requires_grad(); | ||||
| } | } | ||||
| // Given the node, return whether each input is a parameter or a output of a operator. | |||||
| // The returned boolean vector should be the same order of the inputs, thus its implementation | |||||
| // is closely consistent with ExtractShape() in step_parallel.cc | |||||
| std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) { | |||||
| std::vector<bool> is_parameter; | |||||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||||
| // input is a ValueList or ValueTuple, then all inputs are not parameter. | |||||
| if ((node_inputs.size() == 2) && | |||||
| (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) { | |||||
| std::vector<ValuePtr> inputs_seq; | |||||
| if (IsValueNode<ValueList>(node_inputs[1])) { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value(); | |||||
| } else { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value(); | |||||
| } | |||||
| return std::vector<bool>(inputs_seq.size(), false); | |||||
| } | |||||
| if ((node_inputs.size() == 2) && | |||||
| (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) { | |||||
| node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs(); | |||||
| } | |||||
| for (size_t i = 1; i < node_inputs.size(); ++i) { | |||||
| auto input = node_inputs[i]; | |||||
| if (input->isa<Parameter>()) { | |||||
| auto input_parameter = input->cast<ParameterPtr>(); | |||||
| is_parameter.push_back(ParameterRequireGrad(input_parameter)); | |||||
| } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) { | |||||
| is_parameter.push_back(false); | |||||
| } | |||||
| } | |||||
| return is_parameter; | |||||
| } | |||||
| // Given the type, return the number of bytes to represent this type | |||||
| size_t GetLengthOfDataType(const TypePtr &type) { | |||||
| switch (type->type_id()) { | |||||
| case kNumberTypeBool: | |||||
| return sizeof(bool); | |||||
| case kNumberTypeInt8: | |||||
| return sizeof(int8_t); | |||||
| case kNumberTypeInt16: | |||||
| return sizeof(int16_t); | |||||
| case kNumberTypeInt32: | |||||
| return sizeof(int32_t); | |||||
| case kNumberTypeInt64: | |||||
| return sizeof(int64_t); | |||||
| case kNumberTypeUInt8: | |||||
| return sizeof(uint8_t); | |||||
| case kNumberTypeUInt16: | |||||
| return sizeof(uint16_t); | |||||
| case kNumberTypeUInt32: | |||||
| return sizeof(uint32_t); | |||||
| case kNumberTypeUInt64: | |||||
| return sizeof(uint64_t); | |||||
| case kNumberTypeFloat16: | |||||
| return sizeof(float) / 2; | |||||
| case kNumberTypeFloat32: | |||||
| return sizeof(float); | |||||
| case kNumberTypeFloat64: | |||||
| return sizeof(double); | |||||
| case kNumberTypeInt: | |||||
| return sizeof(int64_t); | |||||
| case kNumberTypeUInt: | |||||
| return sizeof(unsigned int64_t); | |||||
| case kNumberTypeFloat: | |||||
| return sizeof(float); | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); | |||||
| } | |||||
| } | |||||
| size_t GetInputsTypeLen(const AnfNodePtr &input) { | |||||
| MS_EXCEPTION_IF_NULL(input); | |||||
| if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) { | |||||
| MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor"; | |||||
| } | |||||
| size_t input_type_len = 0; | |||||
| auto type = input->Type(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| if (type->isa<mindspore::TensorType>()) { | |||||
| auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element(); | |||||
| input_type_len = GetLengthOfDataType(input_element_type); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); | |||||
| } | |||||
| return input_type_len; | |||||
| } | |||||
| std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| std::vector<size_t> inputs_type_len; | |||||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||||
| if ((node_inputs.size() == 2) && | |||||
| (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) { | |||||
| std::vector<ValuePtr> inputs_seq; | |||||
| if (IsValueNode<ValueList>(node_inputs[1])) { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value(); | |||||
| } else { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value(); | |||||
| } | |||||
| for (auto &ele : inputs_seq) { | |||||
| auto tensor = ele->cast<tensor::TensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype())); | |||||
| } | |||||
| return inputs_type_len; | |||||
| } | |||||
| if ((node_inputs.size() == 2) && | |||||
| (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) { | |||||
| node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs(); | |||||
| } | |||||
| // extract input element length | |||||
| for (auto &input : node_inputs) { | |||||
| if (IsValueNode<RefKey>(input)) { | |||||
| auto func_graph = node->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph); | |||||
| if (parameters.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; | |||||
| } | |||||
| inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); | |||||
| } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) { | |||||
| // extract input shape from parameter and apply node | |||||
| inputs_type_len.push_back(GetInputsTypeLen(input)); | |||||
| } | |||||
| } | |||||
| return inputs_type_len; | |||||
| } | |||||
| std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| std::vector<TypePtr> outputs_type; | |||||
| // extract output element type | |||||
| auto primary_output_type = node->Type(); | |||||
| MS_EXCEPTION_IF_NULL(primary_output_type); | |||||
| if (primary_output_type->isa<mindspore::Tuple>()) { | |||||
| // in this case, the output is a tuple | |||||
| auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>(); | |||||
| auto elements = tuple_output_type->elements(); | |||||
| for (auto &ele : elements) { | |||||
| if (ele->isa<mindspore::TensorType>()) { | |||||
| auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element(); | |||||
| outputs_type.push_back(ele_element_type); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| // in this case, the output is a single tensor | |||||
| if (primary_output_type->isa<mindspore::TensorType>()) { | |||||
| auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element(); | |||||
| outputs_type.push_back(element_type); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); | |||||
| } | |||||
| } | |||||
| return outputs_type; | |||||
| } | |||||
| std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> parameters; | |||||
| if (!IsValueNode<RefKey>(node)) { | |||||
| MS_LOG(ERROR) << "The node is not a ref key"; | |||||
| return parameters; | |||||
| } | |||||
| auto ref_key = GetValueNode<RefKeyPtr>(node); | |||||
| MS_EXCEPTION_IF_NULL(ref_key); | |||||
| auto name = ref_key->tag(); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto roots = manager->roots(); | |||||
| if (roots.size() != 1) { | |||||
| MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1"; | |||||
| return parameters; | |||||
| } | |||||
| FuncGraphPtr root_g = roots.back(); | |||||
| MS_EXCEPTION_IF_NULL(root_g); | |||||
| for (auto ¶m_node : root_g->parameters()) { | |||||
| auto param = param_node->cast<ParameterPtr>(); | |||||
| if (param && (name == param->name())) { | |||||
| parameters.push_back(param_node); | |||||
| MS_LOG(INFO) << "The name of ref key is: " << name; | |||||
| return parameters; | |||||
| } | |||||
| } | |||||
| MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter"; | |||||
| return parameters; | |||||
| } | |||||
| bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||||
| return false; | |||||
| } | |||||
| auto value_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| auto prim = GetValueNode<PrimitivePtr>(value_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->name() == prim_name) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) { | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||||
| return false; | |||||
| } | |||||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) { | |||||
| return false; | |||||
| } | |||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->name() == RESHAPE) { | |||||
| auto operator_info = cnode->user_data<OperatorInfo>(); | |||||
| std::string op_info_name = operator_info->name(); | |||||
| if (op_cache->find(op_info_name) != op_cache->end()) { | |||||
| return false; | |||||
| } | |||||
| op_cache->insert(op_info_name); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| // Find previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector. | |||||
| bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) { | |||||
| // if previous node is a parameter, handle it in the outsize. | |||||
| if (node->isa<Parameter>()) { | |||||
| return false; | |||||
| } | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||||
| return false; | |||||
| } | |||||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { | |||||
| *pre_operator_info = node_op_info; | |||||
| *out_index = 0; | |||||
| return true; | |||||
| } | |||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||||
| if (prim->name() == TUPLE_GETITEM) { | |||||
| *out_index = GetTupleGetItemIndex(cnode); | |||||
| // find tuple_get_item's previous node | |||||
| auto pre_node = cnode->input(1); | |||||
| if (!pre_node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | |||||
| } | |||||
| CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | |||||
| auto pre_op_info = pre_cnode->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { | |||||
| *pre_operator_info = pre_op_info; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| for (size_t index = 0; index < cnode->inputs().size(); ++index) { | |||||
| if (prim->name() == DEPEND && index != 1) { | |||||
| continue; | |||||
| } | |||||
| if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) { | |||||
| continue; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| MS_LOG(WARNING) | |||||
| << "FindReshapePreNodeStraCosts failed, if reshape is not the first primitive, there must be some error"; | |||||
| return false; | |||||
| } | |||||
| // Find next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector. | |||||
| // if reshape's output connect to several primitive, return the first layout found | |||||
| bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()); | |||||
| FuncGraphManagerPtr manager = cnode->func_graph()->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| AnfNodeIndexSet node_set = manager->node_users()[cnode]; | |||||
| for (auto &node_pair : node_set) { | |||||
| CNodePtr use_apply = node_pair.first->cast<CNodePtr>(); | |||||
| if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) { | |||||
| continue; | |||||
| } | |||||
| ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim_anf_node); | |||||
| PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(node_prim); | |||||
| MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||||
| continue; | |||||
| } | |||||
| auto op_info = use_apply->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { | |||||
| MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name(); | |||||
| *next_operator_info = op_info; | |||||
| *in_index = node_pair.second - 1; | |||||
| return true; | |||||
| } | |||||
| MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " " | |||||
| << IsParallelCareNode(use_apply) << " " << (op_info != nullptr); | |||||
| if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -18,13 +18,37 @@ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ | #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_ | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <unordered_set> | |||||
| #include "base/base.h" | #include "base/base.h" | ||||
| #include "ir/anf.h" | |||||
| #include "frontend/parallel/ops_info/operator_info.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>; | |||||
| std::string ParameterName(const AnfNodePtr &node_ptr); | std::string ParameterName(const AnfNodePtr &node_ptr); | ||||
| bool ParameterRequireGrad(const AnfNodePtr &node_ptr); | bool ParameterRequireGrad(const AnfNodePtr &node_ptr); | ||||
| size_t GetLengthOfDataType(const TypePtr &type); | |||||
| std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node); | |||||
| std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node); | |||||
| std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node); | |||||
| std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); | |||||
| bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name); | |||||
| bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache); | |||||
| bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index); | |||||
| bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index); | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -63,7 +63,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||||
| // check whether strategy_search_mode is valid | // check whether strategy_search_mode is valid | ||||
| std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode(); | std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode(); | ||||
| if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) { | if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) { | ||||
| // Setting searching mode: dynanic programming as default. | |||||
| // Setting searching mode: dynamic programming as default. | |||||
| strategy_search_mode = DYNAMIC_PROGRAMMING; | strategy_search_mode = DYNAMIC_PROGRAMMING; | ||||
| MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default"; | MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default"; | ||||
| } | } | ||||
| @@ -112,170 +112,6 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||||
| return changes; | return changes; | ||||
| } | } | ||||
| // Given the node, return whether each input is a parameter or a output of a operator. | |||||
| // The returned boolean vector should be the same order of the inputs, thus its implementation | |||||
| // is closely consistent with ExtractShape() in step_parallel.cc | |||||
| std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) { | |||||
| std::vector<bool> is_parameter; | |||||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||||
| // input is a ValueList or ValueTuple, then all inputs are not parameter. | |||||
| if ((node_inputs.size() == 2) && | |||||
| (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) { | |||||
| std::vector<ValuePtr> inputs_seq; | |||||
| if (IsValueNode<ValueList>(node_inputs[1])) { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value(); | |||||
| } else { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value(); | |||||
| } | |||||
| return std::vector<bool>(inputs_seq.size(), false); | |||||
| } | |||||
| if ((node_inputs.size() == 2) && | |||||
| (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) { | |||||
| node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs(); | |||||
| } | |||||
| for (size_t i = 1; i < node_inputs.size(); ++i) { | |||||
| auto input = node_inputs[i]; | |||||
| if (input->isa<Parameter>()) { | |||||
| auto input_parameter = input->cast<ParameterPtr>(); | |||||
| is_parameter.push_back(ParameterRequireGrad(input_parameter)); | |||||
| } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) { | |||||
| is_parameter.push_back(false); | |||||
| } | |||||
| } | |||||
| return is_parameter; | |||||
| } | |||||
| // Given the type, return the number of bytes to represent this type | |||||
| size_t GetLengthOfDataType(const TypePtr &type) { | |||||
| switch (type->type_id()) { | |||||
| case kNumberTypeBool: | |||||
| return sizeof(bool); | |||||
| case kNumberTypeInt8: | |||||
| return sizeof(int8_t); | |||||
| case kNumberTypeInt16: | |||||
| return sizeof(int16_t); | |||||
| case kNumberTypeInt32: | |||||
| return sizeof(int32_t); | |||||
| case kNumberTypeInt64: | |||||
| return sizeof(int64_t); | |||||
| case kNumberTypeUInt8: | |||||
| return sizeof(uint8_t); | |||||
| case kNumberTypeUInt16: | |||||
| return sizeof(uint16_t); | |||||
| case kNumberTypeUInt32: | |||||
| return sizeof(uint32_t); | |||||
| case kNumberTypeUInt64: | |||||
| return sizeof(uint64_t); | |||||
| case kNumberTypeFloat16: | |||||
| return sizeof(float) / 2; | |||||
| case kNumberTypeFloat32: | |||||
| return sizeof(float); | |||||
| case kNumberTypeFloat64: | |||||
| return sizeof(double); | |||||
| case kNumberTypeInt: | |||||
| return sizeof(int64_t); | |||||
| case kNumberTypeUInt: | |||||
| return sizeof(unsigned int64_t); | |||||
| case kNumberTypeFloat: | |||||
| return sizeof(float); | |||||
| default: | |||||
| MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name(); | |||||
| } | |||||
| } | |||||
| size_t GetInputsTypeLen(const AnfNodePtr &input) { | |||||
| MS_EXCEPTION_IF_NULL(input); | |||||
| if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) { | |||||
| MS_LOG(EXCEPTION) << "The input node is not a cnode or parameter or tensor"; | |||||
| } | |||||
| size_t input_type_len = 0; | |||||
| auto type = input->Type(); | |||||
| MS_EXCEPTION_IF_NULL(type); | |||||
| if (type->isa<mindspore::TensorType>()) { | |||||
| auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element(); | |||||
| input_type_len = GetLengthOfDataType(input_element_type); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); | |||||
| } | |||||
| return input_type_len; | |||||
| } | |||||
| std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| std::vector<size_t> inputs_type_len; | |||||
| std::vector<AnfNodePtr> node_inputs{node->inputs()}; | |||||
| if ((node_inputs.size() == 2) && | |||||
| (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) { | |||||
| std::vector<ValuePtr> inputs_seq; | |||||
| if (IsValueNode<ValueList>(node_inputs[1])) { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value(); | |||||
| } else { | |||||
| inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value(); | |||||
| } | |||||
| for (auto &ele : inputs_seq) { | |||||
| auto tensor = ele->cast<tensor::TensorPtr>(); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype())); | |||||
| } | |||||
| return inputs_type_len; | |||||
| } | |||||
| if ((node_inputs.size() == 2) && | |||||
| (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) { | |||||
| node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs(); | |||||
| } | |||||
| // extract input element length | |||||
| for (auto &input : node_inputs) { | |||||
| if (IsValueNode<RefKey>(input)) { | |||||
| auto func_graph = node->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph); | |||||
| if (parameters.size() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; | |||||
| } | |||||
| inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); | |||||
| } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) { | |||||
| // extract input shape from parameter and apply node | |||||
| inputs_type_len.push_back(GetInputsTypeLen(input)); | |||||
| } | |||||
| } | |||||
| return inputs_type_len; | |||||
| } | |||||
| std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| std::vector<TypePtr> outputs_type; | |||||
| // extract output element type | |||||
| auto primary_output_type = node->Type(); | |||||
| MS_EXCEPTION_IF_NULL(primary_output_type); | |||||
| if (primary_output_type->isa<mindspore::Tuple>()) { | |||||
| // in this case, the output is a tuple | |||||
| auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>(); | |||||
| auto elements = tuple_output_type->elements(); | |||||
| for (auto &ele : elements) { | |||||
| if (ele->isa<mindspore::TensorType>()) { | |||||
| auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element(); | |||||
| outputs_type.push_back(ele_element_type); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| // in this case, the output is a single tensor | |||||
| if (primary_output_type->isa<mindspore::TensorType>()) { | |||||
| auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element(); | |||||
| outputs_type.push_back(element_type); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name(); | |||||
| } | |||||
| } | |||||
| return outputs_type; | |||||
| } | |||||
| bool IsElementWiseOperator(const std::string &op_name) { | bool IsElementWiseOperator(const std::string &op_name) { | ||||
| // clang-format off | // clang-format off | ||||
| static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, | static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH, | ||||
| @@ -381,6 +217,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn | |||||
| return true; | return true; | ||||
| } | } | ||||
| void InitCostGraph() { | |||||
| entire_costgraph = std::make_shared<CostGraph>(); | |||||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||||
| } | |||||
| OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { | OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) { | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| @@ -491,8 +332,6 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & | |||||
| // Using CNode's UniqueIds to construct nodes | // Using CNode's UniqueIds to construct nodes | ||||
| Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { | Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { | ||||
| MS_LOG(INFO) << "Constructing nodes for cost graph begins."; | MS_LOG(INFO) << "Constructing nodes for cost graph begins."; | ||||
| entire_costgraph = std::make_shared<CostGraph>(); | |||||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||||
| // The map from CNode's UniqueId to its operatorInfo | // The map from CNode's UniqueId to its operatorInfo | ||||
| std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | ||||
| // The operator_infos in a loop | // The operator_infos in a loop | ||||
| @@ -506,7 +345,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||||
| MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | MS_LOG(EXCEPTION) << "Load strategy checkpoint failed"; | ||||
| } | } | ||||
| } | } | ||||
| // Step 1 | |||||
| for (auto &node : all_nodes) { | for (auto &node : all_nodes) { | ||||
| // NOTE: we only care about splittable Primitive operators | // NOTE: we only care about splittable Primitive operators | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| @@ -588,8 +427,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||||
| // Using CNode's UniqueIdThroughCopys to construct nodes | // Using CNode's UniqueIdThroughCopys to construct nodes | ||||
| Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { | Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) { | ||||
| MS_LOG(INFO) << "Constructing nodes for cost graph begins."; | MS_LOG(INFO) << "Constructing nodes for cost graph begins."; | ||||
| entire_costgraph = std::make_shared<CostGraph>(); | |||||
| entire_costgraph->SetDeviceMemoryAndCostParameter(); | |||||
| // The map from CNode's UniqueIdThroughCopy to its operatorInfo | // The map from CNode's UniqueIdThroughCopy to its operatorInfo | ||||
| std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | std::map<std::string, OperatorInfoPtr> from_cnode_to_info; | ||||
| // The operator_infos in a loop | // The operator_infos in a loop | ||||
| @@ -937,115 +774,6 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| } | } | ||||
| } | } | ||||
| bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) { | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||||
| return false; | |||||
| } | |||||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) { | |||||
| return false; | |||||
| } | |||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->name() == RESHAPE) { | |||||
| auto operator_info = cnode->user_data<OperatorInfo>(); | |||||
| std::string op_info_name = operator_info->name(); | |||||
| if (op_cache->find(op_info_name) != op_cache->end()) { | |||||
| return false; | |||||
| } | |||||
| op_cache->insert(op_info_name); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| // find previous node, then obtain its strategy_cost_ vector to get its layout vector. | |||||
| bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) { | |||||
| // if previous node is a parameter, handle it in the outsize. | |||||
| if (node->isa<Parameter>()) { | |||||
| return false; | |||||
| } | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||||
| return false; | |||||
| } | |||||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { | |||||
| *pre_operator_info = node_op_info; | |||||
| *out_index = 0; | |||||
| return true; | |||||
| } | |||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||||
| if (prim->name() == TUPLE_GETITEM) { | |||||
| *out_index = GetTupleGetItemIndex(cnode); | |||||
| // find tuple_get_item's previous node | |||||
| auto pre_node = cnode->input(1); | |||||
| if (!pre_node->isa<CNode>()) { | |||||
| MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | |||||
| } | |||||
| CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | |||||
| auto pre_op_info = pre_cnode->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { | |||||
| *pre_operator_info = pre_op_info; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| for (size_t index = 0; index < cnode->inputs().size(); ++index) { | |||||
| if (prim->name() == DEPEND && index != 1) { | |||||
| continue; | |||||
| } | |||||
| if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) { | |||||
| continue; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| MS_LOG(WARNING) << "FindPreNodeStraCosts failed, if reshape is not the first primitive, there must be some error"; | |||||
| return false; | |||||
| } | |||||
| // find next node, then obtain its strategy_cost_ vector to get its layout vector. | |||||
| // if reshape's output connect to several primitive, return the first layout found | |||||
| bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| MS_EXCEPTION_IF_NULL(cnode->func_graph()); | |||||
| FuncGraphManagerPtr manager = cnode->func_graph()->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| AnfNodeIndexSet node_set = manager->node_users()[cnode]; | |||||
| for (auto &node_pair : node_set) { | |||||
| CNodePtr use_apply = node_pair.first->cast<CNodePtr>(); | |||||
| if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) { | |||||
| continue; | |||||
| } | |||||
| ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(prim_anf_node); | |||||
| PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(node_prim); | |||||
| MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name(); | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||||
| continue; | |||||
| } | |||||
| auto op_info = use_apply->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { | |||||
| MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | |||||
| *next_operator_info = op_info; | |||||
| *in_index = node_pair.second - 1; | |||||
| return true; | |||||
| } | |||||
| MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | |||||
| << " " << (op_info != nullptr); | |||||
| if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | ||||
| std::unordered_set<std::string> op_cache; | std::unordered_set<std::string> op_cache; | ||||
| for (auto node : all_nodes) { | for (auto node : all_nodes) { | ||||
| @@ -1066,8 +794,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| pre_operator_info = reshape_info; | pre_operator_info = reshape_info; | ||||
| pre_stra_costs = reshape_info->strategy_cost(); | pre_stra_costs = reshape_info->strategy_cost(); | ||||
| } else { | } else { | ||||
| if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) { | |||||
| MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed"; | |||||
| if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) { | |||||
| MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed"; | |||||
| } | } | ||||
| pre_stra_costs = pre_operator_info->strategy_cost(); | pre_stra_costs = pre_operator_info->strategy_cost(); | ||||
| } | } | ||||
| @@ -1075,9 +803,9 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| int64_t in_index = 0; | int64_t in_index = 0; | ||||
| OperatorInfoPtr next_operator_info; | OperatorInfoPtr next_operator_info; | ||||
| std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs; | std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs; | ||||
| bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index); | |||||
| bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index); | |||||
| if (!find_next_node) { | if (!find_next_node) { | ||||
| MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed"; | |||||
| MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed"; | |||||
| } | } | ||||
| // set input_layout and output_layout for reshape. | // set input_layout and output_layout for reshape. | ||||
| // init reshape and set cost for each input_layout and output_layout. | // init reshape and set cost for each input_layout and output_layout. | ||||
| @@ -1122,6 +850,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu | |||||
| // | // | ||||
| // OUTPUT: the determined strategy for each operator. | // OUTPUT: the determined strategy for each operator. | ||||
| InitCostGraph(); | |||||
| // Step 1 | // Step 1 | ||||
| if (CostModelContext::GetInstance()->is_multi_subgraphs()) { | if (CostModelContext::GetInstance()->is_multi_subgraphs()) { | ||||
| if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { | if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { | ||||
| @@ -28,20 +28,14 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| bool IsSplittableOperator(const std::string &); | |||||
| bool IsAutoParallelCareNode(const CNodePtr &); | |||||
| // main step of Auto-parallel | // main step of Auto-parallel | ||||
| bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); | bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); | ||||
| size_t GetLengthOfDataType(const TypePtr &type); | |||||
| std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node); | |||||
| bool IsSplittableOperator(const std::string &); | |||||
| std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node); | |||||
| bool IsAutoParallelCareNode(const CNodePtr &); | |||||
| std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node); | |||||
| void InitCostGraph(); | |||||
| Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root); | ||||
| @@ -292,22 +292,6 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr & | |||||
| return tensorinfo_in.tensor_layout(); | return tensorinfo_in.tensor_layout(); | ||||
| } | } | ||||
| bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| auto cnode = anf_node->cast<CNodePtr>(); | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||||
| return false; | |||||
| } | |||||
| auto value_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| auto prim = GetValueNode<PrimitivePtr>(value_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim->name() == prim_name) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| std::string GetPrimName(const CNodePtr &node) { | std::string GetPrimName(const CNodePtr &node) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!IsValueNode<Primitive>(node->input(0))) { | if (!IsValueNode<Primitive>(node->input(0))) { | ||||
| @@ -1219,42 +1203,6 @@ Shapes GetNodeShape(const AnfNodePtr &node) { | |||||
| return shapes; | return shapes; | ||||
| } | } | ||||
| std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| std::vector<AnfNodePtr> parameters; | |||||
| if (!IsValueNode<RefKey>(node)) { | |||||
| MS_LOG(ERROR) << "The node is not a ref key"; | |||||
| return parameters; | |||||
| } | |||||
| auto ref_key = GetValueNode<RefKeyPtr>(node); | |||||
| MS_EXCEPTION_IF_NULL(ref_key); | |||||
| auto name = ref_key->tag(); | |||||
| auto manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto roots = manager->roots(); | |||||
| if (roots.size() != 1) { | |||||
| MS_LOG(ERROR) << "The size of roots ( " << roots.size() << " ) is not 1"; | |||||
| return parameters; | |||||
| } | |||||
| FuncGraphPtr root_g = roots.back(); | |||||
| MS_EXCEPTION_IF_NULL(root_g); | |||||
| for (auto ¶m_node : root_g->parameters()) { | |||||
| auto param = param_node->cast<ParameterPtr>(); | |||||
| if (param && (name == param->name())) { | |||||
| parameters.push_back(param_node); | |||||
| MS_LOG(INFO) << "The name of ref key is: " << name; | |||||
| return parameters; | |||||
| } | |||||
| } | |||||
| MS_LOG(ERROR) << "The name of ref key is: " << name << ", but have not found the parameter"; | |||||
| return parameters; | |||||
| } | |||||
| Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { | Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| @@ -100,8 +100,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs); | |||||
| Shapes GetNodeShape(const AnfNodePtr &node); | Shapes GetNodeShape(const AnfNodePtr &node); | ||||
| std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); | |||||
| // Extract shape from anfnode | // Extract shape from anfnode | ||||
| std::vector<Shapes> ExtractShape(const CNodePtr &node); | std::vector<Shapes> ExtractShape(const CNodePtr &node); | ||||
| @@ -154,8 +152,6 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root); | |||||
| std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node); | std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node); | ||||
| bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name); | |||||
| using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>; | using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>; | ||||
| using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>; | using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>; | ||||