Merge pull request !3574 from hewei/rename_user_data_functags/v0.7.0-beta
| @@ -267,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo | |||||
| return; | return; | ||||
| } | } | ||||
| auto operator_info = node->GetUserData<parallel::OperatorInfo>(); | |||||
| auto operator_info = node->user_data<parallel::OperatorInfo>(); | |||||
| if (operator_info == nullptr) { | if (operator_info == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { | |||||
| if (graph_obj == nullptr || node == nullptr) { | if (graph_obj == nullptr || node == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>(); | |||||
| auto distributed_operation_info = node->user_data<parallel::OperatorInfo>(); | |||||
| if (distributed_operation_info != nullptr) { | if (distributed_operation_info != nullptr) { | ||||
| auto strategyPtr = distributed_operation_info->strategy(); | auto strategyPtr = distributed_operation_info->strategy(); | ||||
| if (strategyPtr != nullptr) { | if (strategyPtr != nullptr) { | ||||
| @@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||||
| (void)cnode_set.emplace(cnode); | (void)cnode_set.emplace(cnode); | ||||
| } else { | } else { | ||||
| auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | ||||
| @@ -98,7 +98,7 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi | |||||
| return cnode_dist; | return cnode_dist; | ||||
| } | } | ||||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| auto operator_info = cnode->user_data<OperatorInfo>(); | |||||
| MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) | MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) | ||||
| << " operator_info: " << (operator_info != nullptr); | << " operator_info: " << (operator_info != nullptr); | ||||
| @@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { | |||||
| } | } | ||||
| auto para_ptr = node_ptr->cast<ParameterPtr>(); | auto para_ptr = node_ptr->cast<ParameterPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(para_ptr); | MS_EXCEPTION_IF_NULL(para_ptr); | ||||
| auto layout_ptr = para_ptr->GetUserData<TensorLayout>(); | |||||
| auto layout_ptr = para_ptr->user_data<TensorLayout>(); | |||||
| if (layout_ptr == nullptr) { | if (layout_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << "layout_ptr is nullptr!"; | MS_LOG(ERROR) << "layout_ptr is nullptr!"; | ||||
| return FAILED; | return FAILED; | ||||
| @@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { | |||||
| for (auto para : graph_params) { | for (auto para : graph_params) { | ||||
| std::string name = std::static_pointer_cast<Parameter>(para)->name(); | std::string name = std::static_pointer_cast<Parameter>(para)->name(); | ||||
| auto tensor_layout = para->GetUserData<parallel::TensorLayout>(); | |||||
| auto tensor_layout = para->user_data<parallel::TensorLayout>(); | |||||
| if (tensor_layout == nullptr) { | if (tensor_layout == nullptr) { | ||||
| MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; | MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; | ||||
| } else { | } else { | ||||
| @@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { | |||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| auto distributed_operation_info = cnode->GetUserData<OperatorInfo>(); | |||||
| auto distributed_operation_info = cnode->user_data<OperatorInfo>(); | |||||
| if (distributed_operation_info != nullptr) { | if (distributed_operation_info != nullptr) { | ||||
| auto strategyPtr = distributed_operation_info->strategy(); | auto strategyPtr = distributed_operation_info->strategy(); | ||||
| if (strategyPtr != nullptr) { | if (strategyPtr != nullptr) { | ||||
| @@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | ||||
| entire_costgraph->AddOperator(operator_info); | entire_costgraph->AddOperator(operator_info); | ||||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||||
| cnode->set_user_data<OperatorInfo>(operator_info); | |||||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | ||||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | ||||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | ||||
| @@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | ||||
| entire_costgraph->AddOperator(operator_info); | entire_costgraph->AddOperator(operator_info); | ||||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||||
| cnode->set_user_data<OperatorInfo>(operator_info); | |||||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | ||||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | ||||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | ||||
| @@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||||
| MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() | MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() | ||||
| << " does not match the Prim: " << prim->name(); | << " does not match the Prim: " << prim->name(); | ||||
| } | } | ||||
| cnode->SetUserData<OperatorInfo>(current_op_ptr); | |||||
| cnode->set_user_data<OperatorInfo>(current_op_ptr); | |||||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | ||||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | ||||
| << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); | << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); | ||||
| @@ -549,7 +549,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | ||||
| size_t edge_count = 0; | size_t edge_count = 0; | ||||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| auto prev_cnode = inputs[i]->cast<CNodePtr>(); | auto prev_cnode = inputs[i]->cast<CNodePtr>(); | ||||
| @@ -565,7 +565,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | ||||
| while (bool_result) { | while (bool_result) { | ||||
| if (IsAutoParallelCareNode(prev_cnode)) { | if (IsAutoParallelCareNode(prev_cnode)) { | ||||
| auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>(); | |||||
| auto prev_op_info = prev_cnode->user_data<OperatorInfo>(); | |||||
| std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); | std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); | ||||
| // If the edge between these two operators already has been added, then the edge will not be added again. | // If the edge between these two operators already has been added, then the edge will not be added again. | ||||
| if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { | if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { | ||||
| @@ -751,7 +751,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| auto target_cnode = target.first->cast<CNodePtr>(); | auto target_cnode = target.first->cast<CNodePtr>(); | ||||
| auto input_index = target.second; | auto input_index = target.second; | ||||
| (void)target_without_duplicate.insert(std::to_string(input_index) + | (void)target_without_duplicate.insert(std::to_string(input_index) + | ||||
| target_cnode->GetUserData<OperatorInfo>()->name()); | |||||
| target_cnode->user_data<OperatorInfo>()->name()); | |||||
| } | } | ||||
| if (target_without_duplicate.size() <= 1) { | if (target_without_duplicate.size() <= 1) { | ||||
| continue; | continue; | ||||
| @@ -831,7 +831,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| auto target_cnode = target.first->cast<CNodePtr>(); | auto target_cnode = target.first->cast<CNodePtr>(); | ||||
| auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0)); | auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0)); | ||||
| auto input_index = target.second; | auto input_index = target.second; | ||||
| auto target_op_info = target_cnode->GetUserData<OperatorInfo>(); | |||||
| auto target_op_info = target_cnode->user_data<OperatorInfo>(); | |||||
| std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); | std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); | ||||
| // If the edge between these two operators already has been added, then the edge will not be added again. | // If the edge between these two operators already has been added, then the edge will not be added again. | ||||
| @@ -862,7 +862,7 @@ bool FindReshape(const CNodePtr &cnode) { | |||||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ||||
| @@ -884,7 +884,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { | if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { | ||||
| *pre_operator_info = node_op_info; | *pre_operator_info = node_op_info; | ||||
| *out_index = 0; | *out_index = 0; | ||||
| @@ -900,7 +900,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||||
| MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | ||||
| } | } | ||||
| CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | ||||
| auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>(); | |||||
| auto pre_op_info = pre_cnode->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { | if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { | ||||
| *pre_operator_info = pre_op_info; | *pre_operator_info = pre_op_info; | ||||
| return true; | return true; | ||||
| @@ -941,7 +941,7 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto op_info = use_apply->GetUserData<OperatorInfo>(); | |||||
| auto op_info = use_apply->user_data<OperatorInfo>(); | |||||
| if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { | if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { | ||||
| MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | ||||
| *next_operator_info = op_info; | *next_operator_info = op_info; | ||||
| @@ -970,7 +970,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| int32_t out_index = 0; | int32_t out_index = 0; | ||||
| OperatorInfoPtr pre_operator_info; | OperatorInfoPtr pre_operator_info; | ||||
| std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; | std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; | ||||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| auto operator_info = cnode->user_data<OperatorInfo>(); | |||||
| if (pre_node->isa<Parameter>()) { | if (pre_node->isa<Parameter>()) { | ||||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | ||||
| reshape_info->SetCostForReshapeWithParameter(); | reshape_info->SetCostForReshapeWithParameter(); | ||||
| @@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { | |||||
| if (!IsParallelCareNode(node)) { | if (!IsParallelCareNode(node)) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||||
| OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>(); | |||||
| if (distribute_operator == nullptr) { | if (distribute_operator == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; | MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; | ||||
| } | } | ||||
| @@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { | |||||
| if (prim->name() == GET_NEXT) { | if (prim->name() == GET_NEXT) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) { | |||||
| if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) { | |||||
| if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) { | |||||
| Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, | Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, | ||||
| pre_node); | pre_node); | ||||
| } else { | } else { | ||||
| @@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||||
| void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { | void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| MS_EXCEPTION_IF_NULL(next_node); | MS_EXCEPTION_IF_NULL(next_node); | ||||
| OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>(); | |||||
| OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| // If the shape of tensor is [] or [1], no need to split it. | // If the shape of tensor is [] or [1], no need to split it. | ||||
| @@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { | |||||
| void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | ||||
| // step1:get graph manager distribute_operator | // step1:get graph manager distribute_operator | ||||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||||
| OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>(); | |||||
| if (distribute_operator == nullptr) { | if (distribute_operator == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; | MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; | ||||
| } | } | ||||
| @@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | |||||
| (void)prim->SetAttrs(attrs); | (void)prim->SetAttrs(attrs); | ||||
| } | } | ||||
| if (index == replace_op.size() - 1) { | if (index == replace_op.size() - 1) { | ||||
| replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>()); | |||||
| replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>()); | |||||
| } | } | ||||
| replace_node->set_in_forward_flag(true); | replace_node->set_in_forward_flag(true); | ||||
| replace_input[0]->set_scope(scope); | replace_input[0]->set_scope(scope); | ||||
| @@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { | |||||
| auto pre_cnode = pre_node->cast<CNodePtr>(); | auto pre_cnode = pre_node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(pre_cnode); | MS_EXCEPTION_IF_NULL(pre_cnode); | ||||
| auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||||
| if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { | |||||
| pre_node = pre_cnode->input(1); | pre_node = pre_cnode->input(1); | ||||
| } | } | ||||
| @@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) { | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||||
| return node_pair; | return node_pair; | ||||
| } else if (FindParallelCareNode(node_pair.first).first != nullptr) { | } else if (FindParallelCareNode(node_pair.first).first != nullptr) { | ||||
| return FindParallelCareNode(node_pair.first); | return FindParallelCareNode(node_pair.first); | ||||
| @@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||||
| MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); | MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); | ||||
| CNodePtr cnode = res.first->cast<CNodePtr>(); | CNodePtr cnode = res.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>(); | |||||
| OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>(); | |||||
| if (distribute_operator == nullptr) { | if (distribute_operator == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; | MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; | ||||
| } | } | ||||
| @@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||||
| TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | ||||
| ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); | ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(parameter_ptr); | MS_EXCEPTION_IF_NULL(parameter_ptr); | ||||
| parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); | |||||
| parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); | |||||
| } | } | ||||
| void CoverSliceShape(const FuncGraphPtr &root) { | void CoverSliceShape(const FuncGraphPtr &root) { | ||||
| @@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||||
| if (found_be_cloned_parameter) { | if (found_be_cloned_parameter) { | ||||
| // set the shape and tensor layout for cloned parameter | // set the shape and tensor layout for cloned parameter | ||||
| cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>()); | |||||
| cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>()); | |||||
| MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | ||||
| MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | ||||
| auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | ||||
| @@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| (*operator_).set_outputs_dtype(cnode->Type()); | (*operator_).set_outputs_dtype(cnode->Type()); | ||||
| (*operator_).set_cnode(cnode); | (*operator_).set_cnode(cnode); | ||||
| if (prim->name() == RESHAPE) { | if (prim->name() == RESHAPE) { | ||||
| cnode->SetUserData<OperatorInfo>(operator_); | |||||
| cnode->set_user_data<OperatorInfo>(operator_); | |||||
| continue; | continue; | ||||
| } | } | ||||
| // load strategy checkpoint | // load strategy checkpoint | ||||
| @@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| if (operator_->Init(strategyPtr) == FAILED) { | if (operator_->Init(strategyPtr) == FAILED) { | ||||
| MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | ||||
| } | } | ||||
| cnode->SetUserData<OperatorInfo>(operator_); | |||||
| cnode->set_user_data<OperatorInfo>(operator_); | |||||
| } else { | } else { | ||||
| MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; | MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; | ||||
| } | } | ||||
| @@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) { | |||||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | if (node_prim->name() == DEPEND && node_pair.second != 1) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) { | |||||
| if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) { | |||||
| MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); | MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); | ||||
| auto layout = GetInputLayoutFromCNode(node_pair); | auto layout = GetInputLayoutFromCNode(node_pair); | ||||
| return std::make_shared<TensorLayout>(layout); | return std::make_shared<TensorLayout>(layout); | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | ||||
| << " " << use_apply->HasUserData<OperatorInfo>(); | |||||
| << " " << use_apply->has_user_data<OperatorInfo>(); | |||||
| auto layout_ptr = FindNextLayout(use_apply); | auto layout_ptr = FindNextLayout(use_apply); | ||||
| if (layout_ptr) { | if (layout_ptr) { | ||||
| @@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); | auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); | ||||
| if (!layout_ptr) { | if (!layout_ptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | ||||
| @@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { | |||||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | if (!IsValueNode<Primitive>(cnode->input(0))) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); | auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); | ||||
| if (!layout_ptr) { | if (!layout_ptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | ||||
| @@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | ||||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); | |||||
| if (operator_info == nullptr) { | if (operator_info == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | ||||
| } | } | ||||
| @@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { | |||||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| // return -> cast | // return -> cast | ||||
| if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||||
| if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { | |||||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(pre_cnode); | MS_EXCEPTION_IF_NULL(pre_cnode); | ||||
| current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | ||||
| @@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>(); | |||||
| OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>(); | |||||
| MS_EXCEPTION_IF_NULL(operator_info); | MS_EXCEPTION_IF_NULL(operator_info); | ||||
| TensorInfo loss_grad_tensor_info; | TensorInfo loss_grad_tensor_info; | ||||
| size_t op_output_size = operator_info->outputs_tensor_info().size(); | size_t op_output_size = operator_info->outputs_tensor_info().size(); | ||||
| @@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||||
| if (sens_tensor_node->isa<Parameter>()) { | if (sens_tensor_node->isa<Parameter>()) { | ||||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | ||||
| MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); | MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); | ||||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| } | } | ||||
| MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; | MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; | ||||
| return; | return; | ||||
| @@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||||
| cloned_abstract->set_shape(parallel_shape); | cloned_abstract->set_shape(parallel_shape); | ||||
| sens_tensor_node->set_abstract(cloned_abstract); | sens_tensor_node->set_abstract(cloned_abstract); | ||||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | ||||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||||
| return; | return; | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; | MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; | ||||
| @@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||||
| } | } | ||||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | ||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||||
| OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); | |||||
| if (operator_info) { | if (operator_info) { | ||||
| if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | ||||
| continue; | continue; | ||||
| @@ -158,29 +158,29 @@ class AnfNode : public Base { | |||||
| size_t seen_{0}; | size_t seen_{0}; | ||||
| template <typename T> | template <typename T> | ||||
| void SetUserData(const std::string &key, const std::shared_ptr<T> &value) { | |||||
| void set_user_data(const std::string &key, const std::shared_ptr<T> &value) { | |||||
| user_data_.set<T>(key, value); | user_data_.set<T>(key, value); | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void SetUserData(const std::shared_ptr<T> &value) { | |||||
| void set_user_data(const std::shared_ptr<T> &value) { | |||||
| user_data_.set<T>(T::key, value); | user_data_.set<T>(T::key, value); | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| std::shared_ptr<T> GetUserData(const std::string &key) const { | |||||
| std::shared_ptr<T> user_data(const std::string &key) const { | |||||
| return user_data_.get<T>(key); | return user_data_.get<T>(key); | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| std::shared_ptr<T> GetUserData() const { | |||||
| std::shared_ptr<T> user_data() const { | |||||
| return user_data_.get<T>(T::key); | return user_data_.get<T>(T::key); | ||||
| } | } | ||||
| bool HasUserData(const std::string &key) const { return user_data_.has(key); } | |||||
| bool has_user_data(const std::string &key) const { return user_data_.has(key); } | |||||
| template <typename T> | template <typename T> | ||||
| bool HasUserData() const { | |||||
| bool has_user_data() const { | |||||
| return user_data_.has(T::key); | return user_data_.has(T::key); | ||||
| } | } | ||||
| @@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) { | |||||
| StrategyPtr strategyPtr; | StrategyPtr strategyPtr; | ||||
| std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); | std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); | ||||
| node->SetUserData<OperatorInfo>(matmul_info); | |||||
| node->set_user_data<OperatorInfo>(matmul_info); | |||||
| std::string name_expect = "MatMulInfo00"; | std::string name_expect = "MatMulInfo00"; | ||||
| std::string name_test = matmul_info->name(); | std::string name_test = matmul_info->name(); | ||||
| ASSERT_EQ(name_expect, name_test); | ASSERT_EQ(name_expect, name_test); | ||||
| @@ -522,8 +522,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) { | |||||
| std::vector<Shapes> shape = {inputs_shape, outputs_shape}; | std::vector<Shapes> shape = {inputs_shape, outputs_shape}; | ||||
| OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); | OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); | ||||
| matmul_info->Init(strategyPtr); | matmul_info->Init(strategyPtr); | ||||
| node->SetUserData<OperatorInfo>(matmul_info); | |||||
| OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>(); | |||||
| node->set_user_data<OperatorInfo>(matmul_info); | |||||
| OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>(); | |||||
| TensorLayout tensorlayout_e; | TensorLayout tensorlayout_e; | ||||
| std::vector<int32_t> array = {64, 64}; | std::vector<int32_t> array = {64, 64}; | ||||
| TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); | TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); | ||||