Merge pull request !3574 from hewei/rename_user_data_functags/v0.7.0-beta
| @@ -267,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo | |||
| return; | |||
| } | |||
| auto operator_info = node->GetUserData<parallel::OperatorInfo>(); | |||
| auto operator_info = node->user_data<parallel::OperatorInfo>(); | |||
| if (operator_info == nullptr) { | |||
| return; | |||
| } | |||
| @@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { | |||
| if (graph_obj == nullptr || node == nullptr) { | |||
| return; | |||
| } | |||
| auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>(); | |||
| auto distributed_operation_info = node->user_data<parallel::OperatorInfo>(); | |||
| if (distributed_operation_info != nullptr) { | |||
| auto strategyPtr = distributed_operation_info->strategy(); | |||
| if (strategyPtr != nullptr) { | |||
| @@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr ¶, uint32_t | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||
| (void)cnode_set.emplace(cnode); | |||
| } else { | |||
| auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); | |||
| @@ -98,7 +98,7 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi | |||
| return cnode_dist; | |||
| } | |||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| auto operator_info = cnode->user_data<OperatorInfo>(); | |||
| MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) | |||
| << " operator_info: " << (operator_info != nullptr); | |||
| @@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { | |||
| } | |||
| auto para_ptr = node_ptr->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(para_ptr); | |||
| auto layout_ptr = para_ptr->GetUserData<TensorLayout>(); | |||
| auto layout_ptr = para_ptr->user_data<TensorLayout>(); | |||
| if (layout_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "layout_ptr is nullptr!"; | |||
| return FAILED; | |||
| @@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { | |||
| for (auto para : graph_params) { | |||
| std::string name = std::static_pointer_cast<Parameter>(para)->name(); | |||
| auto tensor_layout = para->GetUserData<parallel::TensorLayout>(); | |||
| auto tensor_layout = para->user_data<parallel::TensorLayout>(); | |||
| if (tensor_layout == nullptr) { | |||
| MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; | |||
| } else { | |||
| @@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { | |||
| if (node->isa<CNode>()) { | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto distributed_operation_info = cnode->GetUserData<OperatorInfo>(); | |||
| auto distributed_operation_info = cnode->user_data<OperatorInfo>(); | |||
| if (distributed_operation_info != nullptr) { | |||
| auto strategyPtr = distributed_operation_info->strategy(); | |||
| if (strategyPtr != nullptr) { | |||
| @@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node | |||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | |||
| entire_costgraph->AddOperator(operator_info); | |||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||
| cnode->set_user_data<OperatorInfo>(operator_info); | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| @@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); | |||
| entire_costgraph->AddOperator(operator_info); | |||
| cnode->SetUserData<OperatorInfo>(operator_info); | |||
| cnode->set_user_data<OperatorInfo>(operator_info); | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); | |||
| @@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no | |||
| MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() | |||
| << " does not match the Prim: " << prim->name(); | |||
| } | |||
| cnode->SetUserData<OperatorInfo>(current_op_ptr); | |||
| cnode->set_user_data<OperatorInfo>(current_op_ptr); | |||
| MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() | |||
| << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() | |||
| << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); | |||
| @@ -549,7 +549,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| size_t edge_count = 0; | |||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto prev_cnode = inputs[i]->cast<CNodePtr>(); | |||
| @@ -565,7 +565,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); | |||
| while (bool_result) { | |||
| if (IsAutoParallelCareNode(prev_cnode)) { | |||
| auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>(); | |||
| auto prev_op_info = prev_cnode->user_data<OperatorInfo>(); | |||
| std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); | |||
| // If the edge between these two operators already has been added, then the edge will not be added again. | |||
| if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { | |||
| @@ -751,7 +751,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| auto target_cnode = target.first->cast<CNodePtr>(); | |||
| auto input_index = target.second; | |||
| (void)target_without_duplicate.insert(std::to_string(input_index) + | |||
| target_cnode->GetUserData<OperatorInfo>()->name()); | |||
| target_cnode->user_data<OperatorInfo>()->name()); | |||
| } | |||
| if (target_without_duplicate.size() <= 1) { | |||
| continue; | |||
| @@ -831,7 +831,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { | |||
| auto target_cnode = target.first->cast<CNodePtr>(); | |||
| auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0)); | |||
| auto input_index = target.second; | |||
| auto target_op_info = target_cnode->GetUserData<OperatorInfo>(); | |||
| auto target_op_info = target_cnode->user_data<OperatorInfo>(); | |||
| std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); | |||
| // If the edge between these two operators already has been added, then the edge will not be added again. | |||
| @@ -862,7 +862,7 @@ bool FindReshape(const CNodePtr &cnode) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| } | |||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) { | |||
| return false; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| @@ -884,7 +884,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return false; | |||
| } | |||
| auto node_op_info = cnode->GetUserData<OperatorInfo>(); | |||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||
| if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { | |||
| *pre_operator_info = node_op_info; | |||
| *out_index = 0; | |||
| @@ -900,7 +900,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ | |||
| MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; | |||
| } | |||
| CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); | |||
| auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>(); | |||
| auto pre_op_info = pre_cnode->user_data<OperatorInfo>(); | |||
| if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { | |||
| *pre_operator_info = pre_op_info; | |||
| return true; | |||
| @@ -941,7 +941,7 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| auto op_info = use_apply->GetUserData<OperatorInfo>(); | |||
| auto op_info = use_apply->user_data<OperatorInfo>(); | |||
| if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { | |||
| MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); | |||
| *next_operator_info = op_info; | |||
| @@ -970,7 +970,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { | |||
| int32_t out_index = 0; | |||
| OperatorInfoPtr pre_operator_info; | |||
| std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; | |||
| auto operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| auto operator_info = cnode->user_data<OperatorInfo>(); | |||
| if (pre_node->isa<Parameter>()) { | |||
| auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); | |||
| reshape_info->SetCostForReshapeWithParameter(); | |||
| @@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { | |||
| if (!IsParallelCareNode(node)) { | |||
| return nullptr; | |||
| } | |||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||
| OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>(); | |||
| if (distribute_operator == nullptr) { | |||
| MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; | |||
| } | |||
| @@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { | |||
| if (prim->name() == GET_NEXT) { | |||
| return true; | |||
| } | |||
| if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) { | |||
| if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) { | |||
| return false; | |||
| } | |||
| @@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) { | |||
| if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) { | |||
| Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, | |||
| pre_node); | |||
| } else { | |||
| @@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ | |||
| void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(next_node); | |||
| OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>(); | |||
| OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>(); | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| // If the shape of tensor is [] or [1], no need to split it. | |||
| @@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { | |||
| void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | |||
| // step1:get graph manager distribute_operator | |||
| OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); | |||
| OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>(); | |||
| if (distribute_operator == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; | |||
| } | |||
| @@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { | |||
| (void)prim->SetAttrs(attrs); | |||
| } | |||
| if (index == replace_op.size() - 1) { | |||
| replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>()); | |||
| replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>()); | |||
| } | |||
| replace_node->set_in_forward_flag(true); | |||
| replace_input[0]->set_scope(scope); | |||
| @@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { | |||
| auto pre_cnode = pre_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(pre_cnode); | |||
| auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||
| if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||
| if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { | |||
| pre_node = pre_cnode->input(1); | |||
| } | |||
| @@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) { | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||
| return node_pair; | |||
| } else if (FindParallelCareNode(node_pair.first).first != nullptr) { | |||
| return FindParallelCareNode(node_pair.first); | |||
| @@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||
| MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); | |||
| CNodePtr cnode = res.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>(); | |||
| OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>(); | |||
| if (distribute_operator == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; | |||
| } | |||
| @@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i | |||
| TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); | |||
| ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter_ptr); | |||
| parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); | |||
| parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); | |||
| } | |||
| void CoverSliceShape(const FuncGraphPtr &root) { | |||
| @@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { | |||
| if (found_be_cloned_parameter) { | |||
| // set the shape and tensor layout for cloned parameter | |||
| cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>()); | |||
| cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>()); | |||
| MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); | |||
| MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); | |||
| auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); | |||
| @@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| (*operator_).set_outputs_dtype(cnode->Type()); | |||
| (*operator_).set_cnode(cnode); | |||
| if (prim->name() == RESHAPE) { | |||
| cnode->SetUserData<OperatorInfo>(operator_); | |||
| cnode->set_user_data<OperatorInfo>(operator_); | |||
| continue; | |||
| } | |||
| // load strategy checkpoint | |||
| @@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| if (operator_->Init(strategyPtr) == FAILED) { | |||
| MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; | |||
| } | |||
| cnode->SetUserData<OperatorInfo>(operator_); | |||
| cnode->set_user_data<OperatorInfo>(operator_); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; | |||
| } | |||
| @@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) { | |||
| if (node_prim->name() == DEPEND && node_pair.second != 1) { | |||
| continue; | |||
| } | |||
| if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) { | |||
| if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) { | |||
| MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); | |||
| auto layout = GetInputLayoutFromCNode(node_pair); | |||
| return std::make_shared<TensorLayout>(layout); | |||
| } | |||
| MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) | |||
| << " " << use_apply->HasUserData<OperatorInfo>(); | |||
| << " " << use_apply->has_user_data<OperatorInfo>(); | |||
| auto layout_ptr = FindNextLayout(use_apply); | |||
| if (layout_ptr) { | |||
| @@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return nullptr; | |||
| } | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); | |||
| if (!layout_ptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | |||
| @@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { | |||
| if (!IsValueNode<Primitive>(cnode->input(0))) { | |||
| return nullptr; | |||
| } | |||
| if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { | |||
| if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) { | |||
| auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); | |||
| if (!layout_ptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; | |||
| @@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { | |||
| continue; | |||
| } | |||
| ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { | |||
| if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) { | |||
| continue; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); | |||
| if (operator_info == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; | |||
| } | |||
| @@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { | |||
| auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||
| // return -> cast | |||
| if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { | |||
| if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { | |||
| pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(pre_cnode); | |||
| current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); | |||
| @@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { | |||
| return ret; | |||
| } | |||
| OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>(); | |||
| OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>(); | |||
| MS_EXCEPTION_IF_NULL(operator_info); | |||
| TensorInfo loss_grad_tensor_info; | |||
| size_t op_output_size = operator_info->outputs_tensor_info().size(); | |||
| @@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||
| if (sens_tensor_node->isa<Parameter>()) { | |||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | |||
| MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); | |||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| } | |||
| MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; | |||
| return; | |||
| @@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay | |||
| cloned_abstract->set_shape(parallel_shape); | |||
| sens_tensor_node->set_abstract(cloned_abstract); | |||
| auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); | |||
| sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); | |||
| return; | |||
| } | |||
| MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; | |||
| @@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); | |||
| OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>(); | |||
| if (operator_info) { | |||
| if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { | |||
| continue; | |||
| @@ -158,29 +158,29 @@ class AnfNode : public Base { | |||
| size_t seen_{0}; | |||
| template <typename T> | |||
| void SetUserData(const std::string &key, const std::shared_ptr<T> &value) { | |||
| void set_user_data(const std::string &key, const std::shared_ptr<T> &value) { | |||
| user_data_.set<T>(key, value); | |||
| } | |||
| template <typename T> | |||
| void SetUserData(const std::shared_ptr<T> &value) { | |||
| void set_user_data(const std::shared_ptr<T> &value) { | |||
| user_data_.set<T>(T::key, value); | |||
| } | |||
| template <typename T> | |||
| std::shared_ptr<T> GetUserData(const std::string &key) const { | |||
| std::shared_ptr<T> user_data(const std::string &key) const { | |||
| return user_data_.get<T>(key); | |||
| } | |||
| template <typename T> | |||
| std::shared_ptr<T> GetUserData() const { | |||
| std::shared_ptr<T> user_data() const { | |||
| return user_data_.get<T>(T::key); | |||
| } | |||
| bool HasUserData(const std::string &key) const { return user_data_.has(key); } | |||
| bool has_user_data(const std::string &key) const { return user_data_.has(key); } | |||
| template <typename T> | |||
| bool HasUserData() const { | |||
| bool has_user_data() const { | |||
| return user_data_.has(T::key); | |||
| } | |||
| @@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) { | |||
| StrategyPtr strategyPtr; | |||
| std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); | |||
| node->SetUserData<OperatorInfo>(matmul_info); | |||
| node->set_user_data<OperatorInfo>(matmul_info); | |||
| std::string name_expect = "MatMulInfo00"; | |||
| std::string name_test = matmul_info->name(); | |||
| ASSERT_EQ(name_expect, name_test); | |||
| @@ -522,8 +522,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) { | |||
| std::vector<Shapes> shape = {inputs_shape, outputs_shape}; | |||
| OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); | |||
| matmul_info->Init(strategyPtr); | |||
| node->SetUserData<OperatorInfo>(matmul_info); | |||
| OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>(); | |||
| node->set_user_data<OperatorInfo>(matmul_info); | |||
| OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>(); | |||
| TensorLayout tensorlayout_e; | |||
| std::vector<int32_t> array = {64, 64}; | |||
| TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); | |||