Merge pull request !26170 from hewei/fix_ccntags/v1.6.0
| @@ -39,13 +39,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGatherFusion::Run(const FuncGra | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
| auto instance_name = primitive->instance_name(); | |||
| bool is_allgather = AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName; | |||
| bool is_fusion = AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0; | |||
| bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(cnode->GetAttr(kAttrDuplicated)); | |||
| bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos; | |||
| if (is_allgather && is_fusion && is_recompute && is_from_parallel_optimizer) { | |||
| if (!AnfAlgo::IsAllgather(cnode) || !AnfAlgo::IsFusion(cnode) || !AnfAlgo::IsFromParallelOptimizer(cnode)) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::IsRecompute(cnode)) { | |||
| int64_t fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion); | |||
| if (std::find(parallel_optimizer_recompute_allgather_fusion_ids.begin(), | |||
| parallel_optimizer_recompute_allgather_fusion_ids.end(), | |||
| @@ -58,16 +55,14 @@ bool AdjustDependForParallelOptimizerRecomputeAllGatherFusion::Run(const FuncGra | |||
| } else { | |||
| parallel_optimizer_recompute_allgathers.push_back(node); | |||
| } | |||
| } | |||
| if (!is_recompute && is_fusion && is_allgather && is_from_parallel_optimizer) { | |||
| } else { | |||
| int64_t unrecompute_fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion); | |||
| unrecompute_max_fusion_id = std::max(unrecompute_fusion_id, unrecompute_max_fusion_id); | |||
| bool would_be_recomputed = | |||
| AnfAlgo::HasNodeAttr(kAttrRecompute, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrRecompute); | |||
| if (forward_allgather_recompute_value_in_fusion_group.find(unrecompute_fusion_id) == | |||
| forward_allgather_recompute_value_in_fusion_group.end()) { | |||
| forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] = would_be_recomputed; | |||
| } else if (forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] != would_be_recomputed) { | |||
| auto [iter, inserted] = | |||
| forward_allgather_recompute_value_in_fusion_group.emplace(unrecompute_fusion_id, would_be_recomputed); | |||
| if (!inserted && iter->second != would_be_recomputed) { | |||
| MS_LOG(EXCEPTION) << "In same fusion group, the allgather recompute attribute should be equal. " | |||
| "The normal node is:" | |||
| << cnode->fullname_with_scope(); | |||
| @@ -35,13 +35,10 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto primitive = AnfAlgo::GetCNodePrimitive(cnode); | |||
| auto instance_name = primitive->instance_name(); | |||
| bool is_allgather = AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName; | |||
| bool is_fusion = AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0; | |||
| bool is_recompute = cnode->GetAttr(kAttrDuplicated) != nullptr && GetValue<bool>(cnode->GetAttr(kAttrDuplicated)); | |||
| bool is_from_parallel_optimizer = instance_name.find("parallel_optimizer") != std::string::npos; | |||
| if (is_allgather && is_fusion && is_recompute && is_from_parallel_optimizer) { | |||
| if (!AnfAlgo::IsAllgather(cnode) || !AnfAlgo::IsFusion(cnode) || !AnfAlgo::IsFromParallelOptimizer(cnode)) { | |||
| continue; | |||
| } | |||
| if (AnfAlgo::IsRecompute(cnode)) { | |||
| int64_t fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion); | |||
| if (std::find(parallel_optimizer_recompute_allgather_fusion_ids.begin(), | |||
| parallel_optimizer_recompute_allgather_fusion_ids.end(), | |||
| @@ -54,16 +51,14 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr | |||
| } else { | |||
| parallel_optimizer_recompute_allgathers.push_back(node); | |||
| } | |||
| } | |||
| if (!is_recompute && is_fusion && is_allgather && is_from_parallel_optimizer) { | |||
| } else { | |||
| int64_t unrecompute_fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrFusion); | |||
| unrecompute_max_fusion_id = std::max(unrecompute_fusion_id, unrecompute_max_fusion_id); | |||
| bool would_be_recomputed = | |||
| AnfAlgo::HasNodeAttr(kAttrRecompute, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrRecompute); | |||
| if (forward_allgather_recompute_value_in_fusion_group.find(unrecompute_fusion_id) == | |||
| forward_allgather_recompute_value_in_fusion_group.end()) { | |||
| forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] = would_be_recomputed; | |||
| } else if (forward_allgather_recompute_value_in_fusion_group[unrecompute_fusion_id] != would_be_recomputed) { | |||
| auto [iter, inserted] = | |||
| forward_allgather_recompute_value_in_fusion_group.emplace(unrecompute_fusion_id, would_be_recomputed); | |||
| if (!inserted && iter->second != would_be_recomputed) { | |||
| MS_LOG(EXCEPTION) << "In same fusion group, the allgather recompute attribute should be equal. " | |||
| "The normal node is:" | |||
| << cnode->fullname_with_scope(); | |||
| @@ -347,6 +347,22 @@ class AnfRuntimeAlgorithm { | |||
| static size_t GetOutputNumByAbstract(const AbstractBasePtr &node_abstract); | |||
| // Fetch all outputs of call node. | |||
| static std::vector<KernelWithIndex> GetAllOutputByCallNode(const KernelWithIndex &output_with_index); | |||
| static inline bool IsAllgather(const CNodePtr &cnode) { return GetCNodeName(cnode) == kAllGatherOpName; } | |||
| static inline bool IsFusion(const CNodePtr &cnode) { | |||
| return HasNodeAttr(kAttrFusion, cnode) && GetNodeAttr<int64_t>(cnode, kAttrFusion) > 0; | |||
| } | |||
| static inline bool IsFromParallelOptimizer(const CNodePtr &cnode) { | |||
| auto primitive = GetCNodePrimitive(cnode); | |||
| return (primitive != nullptr) && primitive->instance_name().find("parallel_optimizer") != std::string::npos; | |||
| } | |||
| static inline bool IsRecompute(const CNodePtr &cnode) { | |||
| auto attr_dup = cnode->GetAttr(kAttrDuplicated); | |||
| return attr_dup != nullptr && GetValue<bool>(attr_dup); | |||
| } | |||
| }; | |||
| } // namespace session | |||
| using AnfAlgo = session::AnfRuntimeAlgorithm; | |||
| @@ -340,49 +340,42 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr | |||
| return oss.str(); | |||
| } | |||
| static bool CanUseDumpText(const ValuePtr &value) { | |||
| return (value->isa<RefKey>() || value->isa<Scalar>() || value->isa<StringImm>() || value->isa<tensor::Tensor>() || | |||
| value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>() || value->isa<ValueSlice>() || | |||
| value->isa<Type>() || value->isa<KeywordArg>()); | |||
| } | |||
| std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) { | |||
| std::ostringstream oss; | |||
| bool is_null_ptr = (func_graph == nullptr || value == nullptr); | |||
| if (is_null_ptr) { | |||
| return oss.str(); | |||
| if (func_graph == nullptr || value == nullptr) { | |||
| return ""; | |||
| } | |||
| if (value->isa<Primitive>()) { | |||
| oss << GetPrimitiveText(value->cast<PrimitivePtr>()); | |||
| } else if (value->isa<MetaFuncGraph>()) { | |||
| return GetPrimitiveText(value->cast<PrimitivePtr>()); | |||
| } | |||
| if (value->isa<MetaFuncGraph>()) { | |||
| MetaFuncGraphPtr meta_func_graph = value->cast<MetaFuncGraphPtr>(); | |||
| oss << GetMetaFuncGraphText(meta_func_graph); | |||
| } else if (value->isa<SymbolicKeyInstance>()) { | |||
| oss << GetSymbolicKeyInstanceText(func_graph, value->cast<SymbolicKeyInstancePtr>()); | |||
| } else if (value->isa<RefKey>()) { | |||
| oss << value->DumpText(); | |||
| } else if (value->isa<Scalar>() || value->isa<StringImm>()) { | |||
| oss << value->DumpText(); | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| oss << value->DumpText(); | |||
| } else if (value->isa<parse::Symbol>() || value->isa<None>() || value->isa<Null>()) { | |||
| oss << value->DumpText(); | |||
| } else if (value->isa<ValueSequeue>()) { | |||
| oss << GetSequenceText(func_graph, value); | |||
| } else if (value->isa<ValueDictionary>()) { | |||
| oss << GetDictText(func_graph, value); | |||
| } else if (value->isa<ValueSlice>()) { | |||
| ValueSlicePtr slice = value->cast<ValueSlicePtr>(); | |||
| oss << slice->DumpText(); | |||
| } else if (value->isa<Type>()) { | |||
| oss << value->DumpText(); | |||
| } else if (value->isa<parse::NameSpace>()) { | |||
| oss << GetNameSpaceText(value->cast<parse::NameSpacePtr>()); | |||
| } else if (value->isa<parse::PyObjectWrapper>()) { | |||
| oss << value->type_name(); | |||
| } else if (value->isa<KeywordArg>()) { | |||
| KeywordArgPtr keyword_arg = value->cast<KeywordArgPtr>(); | |||
| oss << keyword_arg->DumpText(); | |||
| } else { | |||
| return GetOtherValueText(func_graph, value); | |||
| return GetMetaFuncGraphText(meta_func_graph); | |||
| } | |||
| return oss.str(); | |||
| if (value->isa<SymbolicKeyInstance>()) { | |||
| return GetSymbolicKeyInstanceText(func_graph, value->cast<SymbolicKeyInstancePtr>()); | |||
| } | |||
| if (value->isa<ValueSequeue>()) { | |||
| return GetSequenceText(func_graph, value); | |||
| } | |||
| if (value->isa<ValueDictionary>()) { | |||
| return GetDictText(func_graph, value); | |||
| } | |||
| if (value->isa<parse::NameSpace>()) { | |||
| return GetNameSpaceText(value->cast<parse::NameSpacePtr>()); | |||
| } | |||
| if (value->isa<parse::PyObjectWrapper>()) { | |||
| return value->type_name(); | |||
| } | |||
| if (CanUseDumpText(value)) { | |||
| return value->DumpText(); | |||
| } | |||
| return GetOtherValueText(func_graph, value); | |||
| } | |||
| // This function is used to output node in CNode's inputs | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -121,7 +121,10 @@ class IncorporateCall : public AnfVisitor { | |||
| (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); | |||
| } | |||
| } | |||
| return MakeNewNode(node, args); | |||
| } | |||
| AnfNodePtr MakeNewNode(const AnfNodePtr &node, const std::vector<AnfNodePtr> &args) { | |||
| auto new_node = node->func_graph()->NewCNode(args); | |||
| new_node->set_abstract(node->abstract()); | |||
| // Check if the another only usage of {G, Xs} is UpdateState{s, {G, Xs}}, if yes, replace | |||
| @@ -181,43 +181,12 @@ class ChoicePartialEliminater : public AnfVisitor { | |||
| AnfNodePtrList UnifyParameters(const size_t &anchor_index, const AnfNodePtrList &fg_list, | |||
| const std::vector<AnfNodePtrList> args_list) { | |||
| std::vector<size_t> inputs_index_list[args_list.size()]; | |||
| size_t extra_input_counter = 0; | |||
| AnfNodePtrList extra_inputs; | |||
| const auto &anchor_args = args_list[anchor_index]; | |||
| size_t anchor_args_size = anchor_args.size(); | |||
| auto anchor_fg = GetValueNode<FuncGraphPtr>(fg_list[anchor_index]); | |||
| MS_EXCEPTION_IF_NULL(anchor_fg); | |||
| // Find the new location of the old_inputs except Zs; | |||
| for (size_t i = 0; i < args_list.size(); ++i) { | |||
| if (i == anchor_index) { | |||
| continue; | |||
| } | |||
| const auto &another_args = args_list[i]; | |||
| auto &curr_inputs_index = inputs_index_list[i]; | |||
| for (size_t j = 0; j < another_args.size(); ++j) { | |||
| size_t k; | |||
| for (k = 0; k < anchor_args_size; ++k) { | |||
| if (another_args[j] == anchor_args[k]) { | |||
| curr_inputs_index.push_back(k); | |||
| break; | |||
| } | |||
| } | |||
| if (k == anchor_args_size) { | |||
| // check if used by another func_graph; | |||
| for (k = 0; k < extra_input_counter; ++k) { | |||
| if (another_args[j] == extra_inputs[k]) { | |||
| curr_inputs_index.push_back(anchor_args_size + k); | |||
| break; | |||
| } | |||
| } | |||
| if (k == extra_input_counter) { | |||
| extra_inputs.push_back(another_args[j]); | |||
| curr_inputs_index.push_back(anchor_args_size + extra_input_counter); | |||
| extra_input_counter++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| size_t extra_input_counter = FindNewLocation(args_list, anchor_index, inputs_index_list, &extra_inputs); | |||
| auto manager = anchor_fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| @@ -284,6 +253,46 @@ class ChoicePartialEliminater : public AnfVisitor { | |||
| return extra_inputs; | |||
| } | |||
| // Find the new location of the old_inputs except Zs. | |||
| size_t FindNewLocation(const std::vector<AnfNodePtrList> &args_list, size_t anchor_index, | |||
| std::vector<size_t> *inputs_index_list, AnfNodePtrList *extra_inputs_ptr) { | |||
| const auto &anchor_args = args_list[anchor_index]; | |||
| auto &extra_inputs = *extra_inputs_ptr; | |||
| size_t extra_input_counter = 0; | |||
| size_t anchor_args_size = anchor_args.size(); | |||
| for (size_t i = 0; i < args_list.size(); ++i) { | |||
| if (i == anchor_index) { | |||
| continue; | |||
| } | |||
| const auto &another_args = args_list[i]; | |||
| auto &curr_inputs_index = inputs_index_list[i]; | |||
| for (size_t j = 0; j < another_args.size(); ++j) { | |||
| size_t k; | |||
| for (k = 0; k < anchor_args_size; ++k) { | |||
| if (another_args[j] == anchor_args[k]) { | |||
| curr_inputs_index.push_back(k); | |||
| break; | |||
| } | |||
| } | |||
| if (k == anchor_args_size) { | |||
| // check if used by another func_graph; | |||
| for (k = 0; k < extra_input_counter; ++k) { | |||
| if (another_args[j] == extra_inputs[k]) { | |||
| curr_inputs_index.push_back(anchor_args_size + k); | |||
| break; | |||
| } | |||
| } | |||
| if (k == extra_input_counter) { | |||
| extra_inputs.push_back(another_args[j]); | |||
| curr_inputs_index.push_back(anchor_args_size + extra_input_counter); | |||
| extra_input_counter++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return extra_input_counter; | |||
| } | |||
| }; | |||
| // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} -> | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -673,6 +673,68 @@ void CheckAndApplyApproximation() { | |||
| } | |||
| } | |||
| static void ConstructCNodeCostGraphEdges(const mindspore::CNodePtr &cnode) { | |||
| auto &inputs = cnode->inputs(); | |||
| ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>(); | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| size_t edge_count = 0; | |||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto prev_cnode = inputs[i]->cast<CNodePtr>(); | |||
| bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | |||
| if (bool_result_prev_cnode) { | |||
| continue; | |||
| } | |||
| ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>(); | |||
| PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| size_t output_index = 0; | |||
| while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || | |||
| (prev_prim->name() == DEPEND)) { | |||
| if (IsAutoParallelCareNode(prev_cnode)) { | |||
| auto prev_op_info = prev_cnode->user_data<OperatorInfo>(); | |||
| CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i, | |||
| &edge_count); | |||
| break; | |||
| } else if (prev_prim->name() == prim::kTupleGetItem) { | |||
| // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before | |||
| // this 'tuple_getitem' | |||
| MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator."; | |||
| output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2)))); | |||
| prev_cnode = prev_cnode->input(1)->cast<CNodePtr>(); | |||
| bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | |||
| if (bool_result_tuple) { | |||
| break; | |||
| } | |||
| prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>(); | |||
| prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| if (!IsAutoParallelCareNode(prev_cnode)) { | |||
| MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name(); | |||
| } | |||
| MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, " | |||
| << "and creating an edge between the Operator before " | |||
| << "'tuple_getitem' and the Operator after 'tuple_getitem'."; | |||
| } else if (prev_prim->name() == DEPEND) { | |||
| // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before | |||
| // this 'depend' | |||
| MS_LOG(INFO) << "Jumping the 'depend' operator."; | |||
| prev_cnode = prev_cnode->input(1)->cast<CNodePtr>(); | |||
| bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | |||
| if (bool_result_depend) { | |||
| break; | |||
| } | |||
| prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>(); | |||
| prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| MS_LOG(INFO) << "Jumped the 'depend' operator, " | |||
| << "and creating an edge between the Operator before " | |||
| << "'depend' and the Operator after 'depend'."; | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); | |||
| } | |||
| void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| // Step 2 | |||
| MS_LOG(INFO) << "Constructing edges for cost graph begins."; | |||
| @@ -681,68 +743,10 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { | |||
| if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { | |||
| continue; | |||
| } | |||
| auto &inputs = cnode->inputs(); | |||
| ValueNodePtr prim_anf_node = inputs[0]->cast<ValueNodePtr>(); | |||
| if (!IsAutoParallelCareNode(cnode)) { | |||
| continue; | |||
| } | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); | |||
| size_t edge_count = 0; | |||
| auto node_op_info = cnode->user_data<OperatorInfo>(); | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto prev_cnode = inputs[i]->cast<CNodePtr>(); | |||
| bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | |||
| if (bool_result_prev_cnode) { | |||
| continue; | |||
| } | |||
| ValueNodePtr prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>(); | |||
| PrimitivePtr prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| size_t output_index = 0; | |||
| while ((IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == prim::kTupleGetItem) || | |||
| (prev_prim->name() == DEPEND)) { | |||
| if (IsAutoParallelCareNode(prev_cnode)) { | |||
| auto prev_op_info = prev_cnode->user_data<OperatorInfo>(); | |||
| CreateEdgeBetweenTwoOps(prev_op_info, node_op_info, cnode, prev_cnode, prim, prev_prim, output_index, i, | |||
| &edge_count); | |||
| break; | |||
| } else if (prev_prim->name() == prim::kTupleGetItem) { | |||
| // In this case, 'prev_anf_node' is 'tuple_getitem', the actual precursor node is node before | |||
| // this 'tuple_getitem' | |||
| MS_LOG(INFO) << "Jumping the 'tuple_getitem' operator."; | |||
| output_index = LongToSize(GetValue<int64_t>(GetValueNode(prev_cnode->input(2)))); | |||
| prev_cnode = prev_cnode->input(1)->cast<CNodePtr>(); | |||
| bool bool_result_tuple = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | |||
| if (bool_result_tuple) { | |||
| break; | |||
| } | |||
| prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>(); | |||
| prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| if (!IsAutoParallelCareNode(prev_cnode)) { | |||
| MS_LOG(EXCEPTION) << "Did not create OperatorInfo for : " << prev_prim->name(); | |||
| } | |||
| MS_LOG(INFO) << "Jumped the 'tuple_getitem' operator, " | |||
| << "and creating an edge between the Operator before " | |||
| << "'tuple_getitem' and the Operator after 'tuple_getitem'."; | |||
| } else if (prev_prim->name() == DEPEND) { | |||
| // In this case, 'prev_anf_node' is 'depend', the actual precursor node is node before | |||
| // this 'depend' | |||
| MS_LOG(INFO) << "Jumping the 'depend' operator."; | |||
| prev_cnode = prev_cnode->input(1)->cast<CNodePtr>(); | |||
| bool bool_result_depend = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0))); | |||
| if (bool_result_depend) { | |||
| break; | |||
| } | |||
| prev_prim_anf_node = prev_cnode->input(0)->cast<ValueNodePtr>(); | |||
| prev_prim = prev_prim_anf_node->value()->cast<PrimitivePtr>(); | |||
| MS_LOG(INFO) << "Jumped the 'depend' operator, " | |||
| << "and creating an edge between the Operator before " | |||
| << "'depend' and the Operator after 'depend'."; | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); | |||
| ConstructCNodeCostGraphEdges(cnode); | |||
| } | |||
| CheckAndApplyApproximation(); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -115,7 +115,9 @@ class IrExportBuilder { | |||
| bool SetValueToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||
| bool SetTypeToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||
| bool SetScalarToAttributeProto_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||
| bool SetScalarToAttributeProtoForInt_ir(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||
| bool SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||
| bool SetScalarToAttributeProtoForInt_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||
| bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto); | |||
| bool SetSequenceToAttributeProto(const ValueSequeuePtr &value, mind_ir::AttributeProto *const attr_proto, | |||
| std::string *const seq_string); | |||
| @@ -831,7 +833,27 @@ bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); | |||
| int64_t attr_value = GetValue<bool>(value) ? 1 : 0; | |||
| attr_proto->set_i(attr_value); | |||
| } else if (value->isa<Int8Imm>()) { | |||
| } else if (SetScalarToAttributeProtoForInt_ir(value, attr_proto)) { | |||
| return true; | |||
| } else if (value->isa<FP32Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); | |||
| attr_proto->set_f(GetValue<float>(value)); | |||
| } else if (value->isa<FP64Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); | |||
| attr_proto->set_d(GetValue<double>(value)); | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR); | |||
| return SetTensorToAttributeProto(value, attr_proto); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name(); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool IrExportBuilder::SetScalarToAttributeProtoForInt_ir(const ValuePtr &value, | |||
| mind_ir::AttributeProto *const attr_proto) { | |||
| if (value->isa<Int8Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8); | |||
| attr_proto->set_i(value->cast<Int8ImmPtr>()->value()); | |||
| } else if (value->isa<Int16Imm>()) { | |||
| @@ -855,17 +877,7 @@ bool IrExportBuilder::SetScalarToAttributeProto_ir(const ValuePtr &value, mind_i | |||
| } else if (value->isa<UInt64Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64); | |||
| attr_proto->set_i(UlongToLong(value->cast<UInt64ImmPtr>()->value())); | |||
| } else if (value->isa<FP32Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); | |||
| attr_proto->set_f(GetValue<float>(value)); | |||
| } else if (value->isa<FP64Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); | |||
| attr_proto->set_d(GetValue<double>(value)); | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR); | |||
| return SetTensorToAttributeProto(value, attr_proto); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name(); | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -899,7 +911,27 @@ bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ | |||
| } else if (value->isa<BoolImm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); | |||
| attr_proto->add_ints(GetValue<bool>(value)); | |||
| } else if (value->isa<Int8Imm>()) { | |||
| } else if (SetScalarToAttributeProtoForInt_irs(value, attr_proto)) { | |||
| return true; | |||
| } else if (value->isa<FP32Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); | |||
| attr_proto->add_floats(GetValue<float>(value)); | |||
| } else if (value->isa<FP64Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); | |||
| attr_proto->add_doubles(GetValue<double>(value)); | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR); | |||
| return SetTensorToAttributeProto(value, attr_proto); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name(); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool IrExportBuilder::SetScalarToAttributeProtoForInt_irs(const ValuePtr &value, | |||
| mind_ir::AttributeProto *const attr_proto) { | |||
| if (value->isa<Int8Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_INT8); | |||
| attr_proto->add_ints(value->cast<Int8ImmPtr>()->value()); | |||
| } else if (value->isa<Int16Imm>()) { | |||
| @@ -923,17 +955,7 @@ bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_ | |||
| } else if (value->isa<UInt64Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_UINT64); | |||
| attr_proto->add_ints(SizeToInt(value->cast<UInt64ImmPtr>()->value())); | |||
| } else if (value->isa<FP32Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_FLOAT); | |||
| attr_proto->add_floats(GetValue<float>(value)); | |||
| } else if (value->isa<FP64Imm>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE); | |||
| attr_proto->add_doubles(GetValue<double>(value)); | |||
| } else if (value->isa<tensor::Tensor>()) { | |||
| attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR); | |||
| return SetTensorToAttributeProto(value, attr_proto); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name(); | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -16,9 +16,11 @@ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <functional> | |||
| #include <algorithm> | |||
| #include "ir/tensor.h" | |||
| #include "ir/param_info.h" | |||
| @@ -384,6 +386,7 @@ class OnnxExporter { | |||
| void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr); | |||
| void MatchAndMarkCNode(const CNodePtr &cnode, std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr); | |||
| void ExportNodes(const FuncGraphPtr &func_graph, std::map<AnfNodePtr, size_t> *node_map_ptr, | |||
| onnx::GraphProto *graph_proto); | |||
| @@ -600,7 +603,7 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorPro | |||
| void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &nodes, | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) { | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> &op_merged_infos = *op_merged_infos_ptr; | |||
| auto &op_merged_infos = *op_merged_infos_ptr; | |||
| for (auto &node : nodes) { | |||
| if (!node->isa<CNode>()) { | |||
| @@ -623,36 +626,41 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto | |||
| // if the key `input` does not exist, just create a new one | |||
| op_merged_infos[input].referred_count += 1; | |||
| } | |||
| // MindSpore Conv + BiasAdd --> ONNX Conv | |||
| if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) && | |||
| IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_CONV; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) && | |||
| IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_GEMM; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(prim::kPrimTupleGetItem) && | |||
| IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("BatchNorm")) && | |||
| GetInt64Value(cnode->input(2)) == 0) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(prim::kPrimTupleGetItem) && | |||
| IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) && | |||
| GetInt64Value(cnode->input(2)) == 0) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(prim::kPrimTupleGetItem) && | |||
| IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("LayerNorm")) && | |||
| GetInt64Value(cnode->input(2)) == 0) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_LAYER_NORM; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } | |||
| MatchAndMarkCNode(cnode, op_merged_infos_ptr); | |||
| } | |||
| } | |||
| void OnnxExporter::MatchAndMarkCNode(const CNodePtr &cnode, | |||
| std::unordered_map<AnfNodePtr, OpMergedInfo> *op_merged_infos_ptr) { | |||
| auto &op_merged_infos = *op_merged_infos_ptr; | |||
| // MindSpore Conv + BiasAdd --> ONNX Conv | |||
| if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) && IsPrimitiveCNode(cnode->input(1), prim::kPrimConv2D)) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_CONV; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(std::make_shared<Primitive>("BiasAdd")) && | |||
| IsPrimitiveCNode(cnode->input(1), prim::kPrimMatMul)) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_GEMM; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(prim::kPrimTupleGetItem) && | |||
| IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("BatchNorm")) && | |||
| GetInt64Value(cnode->input(2)) == 0) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_BATCH_NORM; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(prim::kPrimTupleGetItem) && | |||
| IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("MaxPoolWithArgmax")) && | |||
| GetInt64Value(cnode->input(2)) == 0) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_MAXPOOL_WITH_ARGMAX; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } else if (cnode->IsApply(prim::kPrimTupleGetItem) && | |||
| IsPrimitiveCNode(cnode->input(1), std::make_shared<Primitive>("LayerNorm")) && | |||
| GetInt64Value(cnode->input(2)) == 0) { | |||
| op_merged_infos[cnode].mode = OP_MERGE_LAYER_NORM; | |||
| op_merged_infos[cnode->input(1)].mode = OP_MERGE_IGNORE; | |||
| op_merged_infos[cnode->input(1)].referred_count -= 1; | |||
| } | |||
| } | |||
| @@ -1571,59 +1579,30 @@ void OnnxExporter::ExportPrimGatherV2(const FuncGraphPtr &, const CNodePtr &node | |||
| void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, | |||
| std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { | |||
| // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert | |||
| if (node->IsApply(prim::kPrimReshape)) { | |||
| return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| if (node->IsApply(prim::kPrimReduceMean) || node->IsApply(prim::kPrimReduceSum)) { | |||
| return ExportPrimReduce(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| if (node->IsApply(prim::kPrimTranspose)) { | |||
| return ExportPrimTranspose(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| if (node->IsApply(prim::kPrimStridedSlice)) { | |||
| return ExportPrimStridedSlice(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| if (node->IsApply(prim::kPrimResizeNearestNeighbor)) { | |||
| return ExportPrimResizeNearestNeighbor(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| if (node->IsApply(prim::kPrimConcat)) { | |||
| return ExportPrimConcat(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| // MindSpore Cast(x, T) --> ONNX Cast[to=T](x) | |||
| if (node->IsApply(prim::kPrimCast)) { | |||
| return ExportPrimCast(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| // ONNX PRelu requires unidirectional broadcasting, here need some process | |||
| if (node->IsApply(std::make_shared<Primitive>("PReLU"))) { | |||
| return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| // MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x) | |||
| if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) { | |||
| return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| // MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w)) | |||
| if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) { | |||
| return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| // MindSpore Tile(x) --> ONNX Tile(x, repeat) | |||
| if (node->IsApply(prim::kPrimTile)) { | |||
| return ExportPrimTile(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| // MindSpore Square(x) --> ONNX Pow(x, 2) | |||
| if (node->IsApply(prim::kPrimSquare)) { | |||
| return ExportPrimSquare(func_graph, node, node_map_ptr, graph_proto); | |||
| } | |||
| using ExportFunc = std::function<void(OnnxExporter *, const FuncGraphPtr &, const CNodePtr &, | |||
| std::map<AnfNodePtr, size_t> *, onnx::GraphProto *const)>; | |||
| static std::vector<std::pair<PrimitivePtr, ExportFunc>> export_table = { | |||
| {prim::kPrimReshape, &OnnxExporter::ExportPrimReshape}, | |||
| {prim::kPrimReduceMean, &OnnxExporter::ExportPrimReduce}, | |||
| {prim::kPrimReduceSum, &OnnxExporter::ExportPrimReduce}, | |||
| {prim::kPrimTranspose, &OnnxExporter::ExportPrimTranspose}, | |||
| {prim::kPrimStridedSlice, &OnnxExporter::ExportPrimStridedSlice}, | |||
| {prim::kPrimResizeNearestNeighbor, &OnnxExporter::ExportPrimResizeNearestNeighbor}, | |||
| {prim::kPrimConcat, &OnnxExporter::ExportPrimConcat}, | |||
| {prim::kPrimCast, &OnnxExporter::ExportPrimCast}, | |||
| {prim::kPrimPRelu, &OnnxExporter::ExportPrimPReLU}, | |||
| {prim::kPrimRelu6, &OnnxExporter::ExportPrimReLU6}, | |||
| {prim::kPrimDepthwiseConv2dNative, &OnnxExporter::ExportPrimDepthwiseConv2d}, | |||
| {prim::kPrimTile, &OnnxExporter::ExportPrimTile}, | |||
| {prim::kPrimSquare, &OnnxExporter::ExportPrimSquare}, | |||
| {prim::kPrimGather, &OnnxExporter::ExportPrimGatherV2}, | |||
| }; | |||
| // MindSpore GatherV2(x, indices, axis) --> ONNX Gather(x, indices) | |||
| if (node->IsApply(prim::kPrimGather)) { | |||
| return ExportPrimGatherV2(func_graph, node, node_map_ptr, graph_proto); | |||
| auto iter = std::find_if(export_table.begin(), export_table.end(), | |||
| [&node](const auto &item) { return node->IsApply(item.first); }); | |||
| if (iter != export_table.end()) { | |||
| iter->second(this, func_graph, node, node_map_ptr, graph_proto); | |||
| return; | |||
| } | |||
| auto inputs = node->inputs(); | |||
| @@ -404,14 +404,48 @@ PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) { | |||
| return value->cast<PrimitivePtr>(); | |||
| } | |||
| static std::string GetNodeTargetForVarInputNode(const CNodePtr &cnode) { | |||
| auto &inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> real_inputs; | |||
| const size_t update_state_valid_input_index = 2; | |||
| const size_t make_tuple_valid_input_index = 1; | |||
| if (cnode->IsApply(prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) { | |||
| (void)std::copy(inputs.begin() + SizeToLong(update_state_valid_input_index), inputs.end(), | |||
| std::back_inserter(real_inputs)); | |||
| } else if (cnode->IsApply(prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) { | |||
| (void)std::copy(inputs.begin() + SizeToLong(make_tuple_valid_input_index), inputs.end(), | |||
| std::back_inserter(real_inputs)); | |||
| } | |||
| std::string first_input_target = kTargetUnDefined; | |||
| bool has_diff_target = | |||
| std::any_of(std::rbegin(real_inputs), std::rend(real_inputs), [&first_input_target](const AnfNodePtr &n) { | |||
| auto target = GetOriginNodeTarget(n); | |||
| if (target == kTargetUnDefined) { | |||
| return false; | |||
| } | |||
| if (first_input_target == kTargetUnDefined) { | |||
| first_input_target = target; | |||
| } | |||
| return target != first_input_target; | |||
| }); | |||
| if (!has_diff_target) { | |||
| return first_input_target; | |||
| } | |||
| return kTargetUnDefined; | |||
| } | |||
| static inline bool IsSummaryPrimitiveCNode(const AnfNodePtr &node) { | |||
| return IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) || | |||
| IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary); | |||
| } | |||
| std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto &inputs = cnode->inputs(); | |||
| #ifndef ENABLE_SECURITY | |||
| if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) || | |||
| IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) { | |||
| if (IsSummaryPrimitiveCNode(node)) { | |||
| if (inputs.size() > 1) { | |||
| return GetOriginNodeTarget(inputs[1]); | |||
| } | |||
| @@ -428,31 +462,7 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) { | |||
| return GetOriginNodeTarget(inputs[use_index]); | |||
| } | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) { | |||
| std::vector<AnfNodePtr> real_inputs; | |||
| const size_t update_state_valid_input_index = 2; | |||
| const size_t make_tuple_valid_input_index = 1; | |||
| if (IsPrimitiveCNode(node, prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) { | |||
| (void)std::copy(inputs.begin() + SizeToLong(update_state_valid_input_index), inputs.end(), | |||
| std::back_inserter(real_inputs)); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) { | |||
| (void)std::copy(inputs.begin() + SizeToLong(make_tuple_valid_input_index), inputs.end(), | |||
| std::back_inserter(real_inputs)); | |||
| } | |||
| std::string first_input_target = kTargetUnDefined; | |||
| bool has_diff_target = | |||
| std::any_of(std::rbegin(real_inputs), std::rend(real_inputs), [&first_input_target](const AnfNodePtr &n) { | |||
| auto target = GetOriginNodeTarget(n); | |||
| if (target == kTargetUnDefined) { | |||
| return false; | |||
| } | |||
| if (first_input_target == kTargetUnDefined) { | |||
| first_input_target = target; | |||
| } | |||
| return target != first_input_target; | |||
| }); | |||
| if (!has_diff_target) { | |||
| return first_input_target; | |||
| } | |||
| return GetNodeTargetForVarInputNode(node->cast<CNodePtr>()); | |||
| } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { | |||
| return GetOriginNodeTarget(cnode->input(1)); | |||
| } | |||