From: @xiaoda_zh
Reviewed-by: @kisnwang, @stsuteng
Signed-off-by: @stsuteng
tags/v1.1.0
@@ -18,10 +18,11 @@
#include <string>
#include "ir/anf.h"
#include "ir/param_info.h"
#include "ir/meta_tensor.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
@@ -45,5 +46,331 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
  }
  return param_value->requires_grad();
}
// Given the node, return whether each input is a parameter or an output of an operator.
// The returned boolean vector follows the order of the inputs, so its implementation
// is kept closely consistent with ExtractShape() in step_parallel.cc.
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  std::vector<bool> is_parameter;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  // If the input is a ValueList or a ValueTuple, none of the inputs is a parameter.
  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    return std::vector<bool>(inputs_seq.size(), false);
  }
  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  for (size_t i = 1; i < node_inputs.size(); ++i) {
    auto input = node_inputs[i];
    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      is_parameter.push_back(ParameterRequireGrad(input_parameter));
    } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
      is_parameter.push_back(false);
    }
  }
  return is_parameter;
}
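// Usage sketch (illustrative, not part of the patch): for a node such as
// MatMul(x, w), where x is another operator's output and w is a trainable
// Parameter, the call below yields {false, true} -- one flag per input, in order.
//   std::vector<bool> flags = ExtractInputParameterByNode(matmul_cnode);  // matmul_cnode is hypothetical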
// Given the type, return the number of bytes needed to represent it.
size_t GetLengthOfDataType(const TypePtr &type) {
  switch (type->type_id()) {
    case kNumberTypeBool:
      return sizeof(bool);
    case kNumberTypeInt8:
      return sizeof(int8_t);
    case kNumberTypeInt16:
      return sizeof(int16_t);
    case kNumberTypeInt32:
      return sizeof(int32_t);
    case kNumberTypeInt64:
      return sizeof(int64_t);
    case kNumberTypeUInt8:
      return sizeof(uint8_t);
    case kNumberTypeUInt16:
      return sizeof(uint16_t);
    case kNumberTypeUInt32:
      return sizeof(uint32_t);
    case kNumberTypeUInt64:
      return sizeof(uint64_t);
    case kNumberTypeFloat16:
      return sizeof(float) / 2;
    case kNumberTypeFloat32:
      return sizeof(float);
    case kNumberTypeFloat64:
      return sizeof(double);
    case kNumberTypeInt:
      return sizeof(int64_t);
    case kNumberTypeUInt:
      return sizeof(uint64_t);
    case kNumberTypeFloat:
      return sizeof(float);
    default:
      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  }
}
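// Usage sketch (illustrative): byte widths follow the host ABI; assuming the
// usual global TypePtr constants kFloat32 and kFloat16 from ir/dtype:
//   size_t f32_len = GetLengthOfDataType(kFloat32);  // 4
//   size_t f16_len = GetLengthOfDataType(kFloat16);  // 2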
size_t GetInputsTypeLen(const AnfNodePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
    MS_LOG(EXCEPTION) << "The input node is not a CNode, Parameter, or Tensor";
  }
  size_t input_type_len = 0;
  auto type = input->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<mindspore::TensorType>()) {
    auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
    input_type_len = GetLengthOfDataType(input_element_type);
  } else {
    MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  }
  return input_type_len;
}
std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<size_t> inputs_type_len;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    for (auto &ele : inputs_seq) {
      auto tensor = ele->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
    }
    return inputs_type_len;
  }
  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  // extract the element length of each input
  for (auto &input : node_inputs) {
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
    } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
      // extract the input type length from a parameter or an operator's output
      inputs_type_len.push_back(GetInputsTypeLen(input));
    }
  }
  return inputs_type_len;
}
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypePtr> outputs_type;
  // extract the output element type(s)
  auto primary_output_type = node->Type();
  MS_EXCEPTION_IF_NULL(primary_output_type);
  if (primary_output_type->isa<mindspore::Tuple>()) {
    // in this case, the output is a tuple
    auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
    auto elements = tuple_output_type->elements();
    for (auto &ele : elements) {
      if (ele->isa<mindspore::TensorType>()) {
        auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
        outputs_type.push_back(ele_element_type);
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << ele->type_name();
      }
    }
  } else {
    // in this case, the output is a single tensor
    if (primary_output_type->isa<mindspore::TensorType>()) {
      auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
      outputs_type.push_back(element_type);
    } else {
      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
    }
  }
  return outputs_type;
}
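// Usage sketch (illustrative): combining the extractors to get the per-element
// byte width of every output of an operator:
//   for (const auto &t : ExtractOutputTypeByNode(cnode)) {
//     size_t bytes_per_element = GetLengthOfDataType(t);
//   }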
std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> parameters;
  if (!IsValueNode<RefKey>(node)) {
    MS_LOG(ERROR) << "The node is not a ref key";
    return parameters;
  }
  auto ref_key = GetValueNode<RefKeyPtr>(node);
  MS_EXCEPTION_IF_NULL(ref_key);
  auto name = ref_key->tag();
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto roots = manager->roots();
  if (roots.size() != 1) {
    MS_LOG(ERROR) << "The size of roots (" << roots.size() << ") is not 1";
    return parameters;
  }
  FuncGraphPtr root_g = roots.back();
  MS_EXCEPTION_IF_NULL(root_g);
  for (auto &param_node : root_g->parameters()) {
    auto param = param_node->cast<ParameterPtr>();
    if (param && (name == param->name())) {
      parameters.push_back(param_node);
      MS_LOG(INFO) << "The name of the ref key is: " << name;
      return parameters;
    }
  }
  MS_LOG(ERROR) << "The name of the ref key is: " << name << ", but the parameter was not found";
  return parameters;
}
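// Usage sketch (illustrative), mirroring the RefKey branch of
// ExtractInputTypeLengthByNode above: resolve the RefKey to its root-graph
// Parameter before measuring it.
//   auto params = FindParameterByRefKeyNode(ref_key_node, func_graph);
//   if (params.size() == 1) {
//     size_t len = GetInputsTypeLen(params[0]);
//   }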
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto cnode = anf_node->cast<CNodePtr>();
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == prim_name) {
    return true;
  }
  return false;
}
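// Usage sketch (illustrative): this is how the extractors above unwrap a
// MakeTuple wrapper before walking the real inputs:
//   if (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
//     node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
//   }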
bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == RESHAPE) {
    auto operator_info = cnode->user_data<OperatorInfo>();
    std::string op_info_name = operator_info->name();
    if (op_cache->find(op_info_name) != op_cache->end()) {
      return false;
    }
    op_cache->insert(op_info_name);
    return true;
  }
  return false;
}
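// Usage sketch (illustrative): callers share one cache across the whole pass so
// each Reshape operator is costed only once (cf. ReshapeCostCompute below):
//   std::unordered_set<std::string> op_cache;
//   for (auto &node : all_nodes) {
//     auto cnode = node->cast<CNodePtr>();
//     if (FindReshape(cnode, &op_cache)) { /* compute the reshape cost */ }
//   }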
// Find the previous node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) {
  // if the previous node is a parameter, handle it outside this function.
  if (node->isa<Parameter>()) {
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto node_op_info = cnode->user_data<OperatorInfo>();
  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
    *pre_operator_info = node_op_info;
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == TUPLE_GETITEM) {
    *out_index = GetTupleGetItemIndex(cnode);
    // find tuple_get_item's previous node
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple_get_item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
      *pre_operator_info = pre_op_info;
      return true;
    }
    return false;
  }
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindReshapePreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING)
    << "FindReshapePreNodeStraCosts failed; unless Reshape is the first primitive, this indicates an error";
  return false;
}
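// Walk-through (illustrative): the recursion sees through TupleGetItem and
// Depend wrappers, so for Reshape(Depend(MatMul(...), u)) it resolves
// *pre_operator_info to the MatMul operator's info with *out_index 0.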
// Find the next node of Reshape, then obtain its strategy_cost_ vector to get its layout vector.
// If Reshape's output connects to several primitives, return the first layout found.
bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    auto op_info = use_apply->user_data<OperatorInfo>();
    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = op_info;
      *in_index = node_pair.second - 1;
      return true;
    }
    MS_LOG(DEBUG) << "FindReshapeNextNodeStraCosts failed prim " << node_prim->name() << " "
                  << IsParallelCareNode(use_apply) << " " << (op_info != nullptr);
    if (FindReshapeNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
      return true;
    }
  }
  return false;
}
}  // namespace parallel
}  // namespace mindspore
@@ -18,13 +18,37 @@
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_GRAPH_UTIL_NODE_INFO_H_
#include <string>
#include <vector>
#include <memory>
#include <unordered_set>
#include "base/base.h"
#include "ir/anf.h"
#include "frontend/parallel/ops_info/operator_info.h"
namespace mindspore {
namespace parallel {
using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
std::string ParameterName(const AnfNodePtr &node_ptr);
bool ParameterRequireGrad(const AnfNodePtr &node_ptr);
size_t GetLengthOfDataType(const TypePtr &type);
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node);
std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache);
bool FindReshapePreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index);
bool FindReshapeNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index);
}  // namespace parallel
}  // namespace mindspore
@@ -63,7 +63,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   // check whether strategy_search_mode is valid
   std::string strategy_search_mode = ParallelContext::GetInstance()->strategy_search_mode();
   if ((strategy_search_mode != DYNAMIC_PROGRAMMING) && (strategy_search_mode != RECURSIVE_PROGRAMMING)) {
-    // Setting searching mode: dynanic programming as default.
+    // Setting searching mode: dynamic programming as default.
     strategy_search_mode = DYNAMIC_PROGRAMMING;
     MS_LOG(INFO) << "Non-idicated strategy searching mode, using DP searching mode as default";
   }
@@ -112,170 +112,6 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
  return changes;
}
// Given the node, return whether each input is a parameter or an output of an operator.
// The returned boolean vector follows the order of the inputs, so its implementation
// is kept closely consistent with ExtractShape() in step_parallel.cc.
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
  std::vector<bool> is_parameter;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  // If the input is a ValueList or a ValueTuple, none of the inputs is a parameter.
  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    return std::vector<bool>(inputs_seq.size(), false);
  }
  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  for (size_t i = 1; i < node_inputs.size(); ++i) {
    auto input = node_inputs[i];
    if (input->isa<Parameter>()) {
      auto input_parameter = input->cast<ParameterPtr>();
      is_parameter.push_back(ParameterRequireGrad(input_parameter));
    } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
      is_parameter.push_back(false);
    }
  }
  return is_parameter;
}
// Given the type, return the number of bytes needed to represent it.
size_t GetLengthOfDataType(const TypePtr &type) {
  switch (type->type_id()) {
    case kNumberTypeBool:
      return sizeof(bool);
    case kNumberTypeInt8:
      return sizeof(int8_t);
    case kNumberTypeInt16:
      return sizeof(int16_t);
    case kNumberTypeInt32:
      return sizeof(int32_t);
    case kNumberTypeInt64:
      return sizeof(int64_t);
    case kNumberTypeUInt8:
      return sizeof(uint8_t);
    case kNumberTypeUInt16:
      return sizeof(uint16_t);
    case kNumberTypeUInt32:
      return sizeof(uint32_t);
    case kNumberTypeUInt64:
      return sizeof(uint64_t);
    case kNumberTypeFloat16:
      return sizeof(float) / 2;
    case kNumberTypeFloat32:
      return sizeof(float);
    case kNumberTypeFloat64:
      return sizeof(double);
    case kNumberTypeInt:
      return sizeof(int64_t);
    case kNumberTypeUInt:
      return sizeof(uint64_t);
    case kNumberTypeFloat:
      return sizeof(float);
    default:
      MS_LOG(EXCEPTION) << "Unexpected type " << type->type_name();
  }
}
size_t GetInputsTypeLen(const AnfNodePtr &input) {
  MS_EXCEPTION_IF_NULL(input);
  if (!input->isa<CNode>() && !input->isa<Parameter>() && !IsValueNode<tensor::Tensor>(input)) {
    MS_LOG(EXCEPTION) << "The input node is not a CNode, Parameter, or Tensor";
  }
  size_t input_type_len = 0;
  auto type = input->Type();
  MS_EXCEPTION_IF_NULL(type);
  if (type->isa<mindspore::TensorType>()) {
    auto input_element_type = type->cast<mindspore::TensorTypePtr>()->element();
    input_type_len = GetLengthOfDataType(input_element_type);
  } else {
    MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name();
  }
  return input_type_len;
}
std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<size_t> inputs_type_len;
  std::vector<AnfNodePtr> node_inputs{node->inputs()};
  if ((node_inputs.size() == 2) &&
      (IsValueNode<ValueList>(node_inputs[1]) || IsValueNode<ValueTuple>(node_inputs[1]))) {
    std::vector<ValuePtr> inputs_seq;
    if (IsValueNode<ValueList>(node_inputs[1])) {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueListPtr>()->value();
    } else {
      inputs_seq = node_inputs[1]->cast<ValueNodePtr>()->value()->cast<ValueTuplePtr>()->value();
    }
    for (auto &ele : inputs_seq) {
      auto tensor = ele->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      inputs_type_len.push_back(GetLengthOfDataType(tensor->Dtype()));
    }
    return inputs_type_len;
  }
  if ((node_inputs.size() == 2) &&
      (AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE) || AnfNodeIsPrimitive(node_inputs[1], MAKE_LIST))) {
    node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
  }
  // extract the element length of each input
  for (auto &input : node_inputs) {
    if (IsValueNode<RefKey>(input)) {
      auto func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      std::vector<AnfNodePtr> parameters = FindParameterByRefKeyNode(input, func_graph);
      if (parameters.size() != 1) {
        MS_LOG(EXCEPTION) << "Find parameter by ref key node failed";
      }
      inputs_type_len.push_back(GetInputsTypeLen(parameters[0]));
    } else if (input->isa<CNode>() || input->isa<Parameter>() || IsValueNode<tensor::Tensor>(input)) {
      // extract the input type length from a parameter or an operator's output
      inputs_type_len.push_back(GetInputsTypeLen(input));
    }
  }
  return inputs_type_len;
}
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypePtr> outputs_type;
  // extract the output element type(s)
  auto primary_output_type = node->Type();
  MS_EXCEPTION_IF_NULL(primary_output_type);
  if (primary_output_type->isa<mindspore::Tuple>()) {
    // in this case, the output is a tuple
    auto tuple_output_type = primary_output_type->cast<mindspore::TuplePtr>();
    auto elements = tuple_output_type->elements();
    for (auto &ele : elements) {
      if (ele->isa<mindspore::TensorType>()) {
        auto ele_element_type = ele->cast<mindspore::TensorTypePtr>()->element();
        outputs_type.push_back(ele_element_type);
      } else {
        MS_LOG(EXCEPTION) << "Unknown type: " << ele->type_name();
      }
    }
  } else {
    // in this case, the output is a single tensor
    if (primary_output_type->isa<mindspore::TensorType>()) {
      auto element_type = primary_output_type->cast<mindspore::TensorTypePtr>()->element();
      outputs_type.push_back(element_type);
    } else {
      MS_LOG(EXCEPTION) << "Unknown type: " << primary_output_type->type_name();
    }
  }
  return outputs_type;
}
bool IsElementWiseOperator(const std::string &op_name) {
  // clang-format off
  static const std::set<std::string> elementwise_op = {ACTIVATION, GELU, TANH,
@@ -381,6 +217,11 @@ bool IsOperatorsInTwoSeparateLoops(const CNodePtr &a_cnode, const CNodePtr &b_cn
  return true;
}
void InitCostGraph() {
  entire_costgraph = std::make_shared<CostGraph>();
  entire_costgraph->SetDeviceMemoryAndCostParameter();
}
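// Call-order sketch (illustrative): with InitCostGraph() factored out, both
// ConstructCostGraphNodesByUniqueId and ConstructCostGraphNodesByUniqueIdTC
// assume the caller has already created entire_costgraph, which
// ParallelStrategySearch now does exactly once:
//   InitCostGraph();
//   ConstructCostGraphNodesByUniqueId(all_nodes, root);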
OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode, StrategyMap *stra_map) {
  MS_EXCEPTION_IF_NULL(prim);
  MS_EXCEPTION_IF_NULL(cnode);
@@ -491,8 +332,6 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
// Using CNode's UniqueIds to construct nodes
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  entire_costgraph = std::make_shared<CostGraph>();
  entire_costgraph->SetDeviceMemoryAndCostParameter();
  // The map from CNode's UniqueId to its operatorInfo
  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  // The operator_infos in a loop
@@ -506,7 +345,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
    }
  }
  // Step 1
  for (auto &node : all_nodes) {
    // NOTE: we only care about splittable Primitive operators
    auto cnode = node->cast<CNodePtr>();
@@ -588,8 +427,6 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
// Using CNode's UniqueIdThroughCopys to construct nodes
Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &) {
  MS_LOG(INFO) << "Constructing nodes for cost graph begins.";
  entire_costgraph = std::make_shared<CostGraph>();
  entire_costgraph->SetDeviceMemoryAndCostParameter();
  // The map from CNode's UniqueIdThroughCopy to its operatorInfo
  std::map<std::string, OperatorInfoPtr> from_cnode_to_info;
  // The operator_infos in a loop
@@ -937,115 +774,6 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
  }
}
bool FindReshape(const CNodePtr &cnode, std::unordered_set<std::string> *op_cache) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == RESHAPE) {
    auto operator_info = cnode->user_data<OperatorInfo>();
    std::string op_info_name = operator_info->name();
    if (op_cache->find(op_info_name) != op_cache->end()) {
      return false;
    }
    op_cache->insert(op_info_name);
    return true;
  }
  return false;
}
// Find the previous node, then obtain its strategy_cost_ vector to get its layout vector.
bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int64_t *out_index) {
  // if the previous node is a parameter, handle it outside this function.
  if (node->isa<Parameter>()) {
    return false;
  }
  if (!node->isa<CNode>()) {
    return false;
  }
  CNodePtr cnode = node->cast<CNodePtr>();
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto node_op_info = cnode->user_data<OperatorInfo>();
  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
    *pre_operator_info = node_op_info;
    *out_index = 0;
    return true;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
  if (prim->name() == TUPLE_GETITEM) {
    *out_index = GetTupleGetItemIndex(cnode);
    // find tuple_get_item's previous node
    auto pre_node = cnode->input(1);
    if (!pre_node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "tuple_get_item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
      *pre_operator_info = pre_op_info;
      return true;
    }
    return false;
  }
  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
    if (prim->name() == DEPEND && index != 1) {
      continue;
    }
    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
      continue;
    }
    return true;
  }
  MS_LOG(WARNING) << "FindPreNodeStraCosts failed; unless Reshape is the first primitive, this indicates an error";
  return false;
}
// Find the next node, then obtain its strategy_cost_ vector to get its layout vector.
// If Reshape's output connects to several primitives, return the first layout found.
bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int64_t *in_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(prim_anf_node);
    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
    MS_EXCEPTION_IF_NULL(node_prim);
    MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
    auto op_info = use_apply->user_data<OperatorInfo>();
    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = op_info;
      *in_index = node_pair.second - 1;
      return true;
    }
    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
                  << " " << (op_info != nullptr);
    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
      return true;
    }
  }
  return false;
}
void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
  std::unordered_set<std::string> op_cache;
  for (auto node : all_nodes) {
@@ -1066,8 +794,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
       pre_operator_info = reshape_info;
       pre_stra_costs = reshape_info->strategy_cost();
     } else {
-      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
-        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
+      if (!FindReshapePreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
+        MS_LOG(EXCEPTION) << "FindReshapePreNodeStraCosts for reshape failed";
       }
       pre_stra_costs = pre_operator_info->strategy_cost();
     }
@@ -1075,9 +803,9 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
     int64_t in_index = 0;
     OperatorInfoPtr next_operator_info;
     std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
-    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
+    bool find_next_node = FindReshapeNextNodeStraCosts(cnode, &next_operator_info, &in_index);
     if (!find_next_node) {
-      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
+      MS_LOG(INFO) << "FindReshapeNextNodeStraCosts for reshape failed";
     }
     // set input_layout and output_layout for reshape.
     // init reshape and set cost for each input_layout and output_layout.
@@ -1122,6 +850,7 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const Fu
   //
   // OUTPUT: the determined strategy for each operator.
   InitCostGraph();
+  // Step 1
   if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
     if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
@@ -28,20 +28,14 @@
namespace mindspore {
namespace parallel {
bool IsSplittableOperator(const std::string &);
bool IsAutoParallelCareNode(const CNodePtr &);
// main step of Auto-parallel
bool StepAutoParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer);
size_t GetLengthOfDataType(const TypePtr &type);
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node);
bool IsSplittableOperator(const std::string &);
std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node);
bool IsAutoParallelCareNode(const CNodePtr &);
std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node);
void InitCostGraph();
Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root);
@@ -292,22 +292,6 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
  return tensorinfo_in.tensor_layout();
}
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto cnode = anf_node->cast<CNodePtr>();
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
  auto value_node = cnode->input(0)->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == prim_name) {
    return true;
  }
  return false;
}
std::string GetPrimName(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsValueNode<Primitive>(node->input(0))) {
std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> parameters;
  if (!IsValueNode<RefKey>(node)) {
    MS_LOG(ERROR) << "The node is not a ref key";
    return parameters;
  }
  auto ref_key = GetValueNode<RefKeyPtr>(node);
  MS_EXCEPTION_IF_NULL(ref_key);
  auto name = ref_key->tag();
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto roots = manager->roots();
  if (roots.size() != 1) {
    MS_LOG(ERROR) << "The size of roots (" << roots.size() << ") is not 1";
    return parameters;
  }
  FuncGraphPtr root_g = roots.back();
  MS_EXCEPTION_IF_NULL(root_g);
  for (auto &param_node : root_g->parameters()) {
    auto param = param_node->cast<ParameterPtr>();
    if (param && (name == param->name())) {
      parameters.push_back(param_node);
      MS_LOG(INFO) << "The name of the ref key is: " << name;
      return parameters;
    }
  }
  MS_LOG(ERROR) << "The name of the ref key is: " << name << ", but the parameter was not found";
  return parameters;
}
Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
@@ -100,8 +100,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs);
Shapes GetNodeShape(const AnfNodePtr &node);
std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
// Extract shape from anfnode
std::vector<Shapes> ExtractShape(const CNodePtr &node);
@@ -154,8 +152,6 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
using ParameterUsersInfo = std::pair<std::string, std::pair<AnfNodePtr, AnfNodeIndexSet>>;