Merge pull request !3216 from huanghui/add-op-mapping-attrs-for-low-level-opt-passtags/v0.6.0-beta
| @@ -212,6 +212,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||||
| MS_EXCEPTION_IF_NULL(kernel_select); | MS_EXCEPTION_IF_NULL(kernel_select); | ||||
| kernel_select->SelectKernel(trans_node); | kernel_select->SelectKernel(trans_node); | ||||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); | AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node); | ||||
| AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node); | |||||
| MS_EXCEPTION_IF_NULL(trans_node); | MS_EXCEPTION_IF_NULL(trans_node); | ||||
| trans_node->set_scope(input->scope()); | trans_node->set_scope(input->scope()); | ||||
| return trans_node; | return trans_node; | ||||
| @@ -250,6 +251,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); | ||||
| AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); | AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get()); | ||||
| AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast); | ||||
| AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast); | |||||
| return cast; | return cast; | ||||
| } | } | ||||
| @@ -354,6 +356,7 @@ AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node | |||||
| MS_EXCEPTION_IF_NULL(new_node); | MS_EXCEPTION_IF_NULL(new_node); | ||||
| new_node->set_abstract(node->abstract()); | new_node->set_abstract(node->abstract()); | ||||
| new_node->set_scope(node->scope()); | new_node->set_scope(node->scope()); | ||||
| AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), new_node); | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" | #include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | |||||
| #include <vector> | #include <vector> | ||||
| #include "backend/session/anf_runtime_algorithm.h" | #include "backend/session/anf_runtime_algorithm.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -111,6 +112,13 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP | |||||
| CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); | CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs); | ||||
| auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); | auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]); | ||||
| // Add attr mapping from original nodes to fusion nodes | |||||
| auto original_names = | |||||
| MakeValue<std::vector<std::string>>({relu->fullname_with_scope(), relu_grad->fullname_with_scope()}); | |||||
| AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, original_names, relu_v2); | |||||
| AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, original_names, relu_grad_v2); | |||||
| AnfAlgo::SetNodeAttr(kAttrDatadumpIsMultiop, MakeValue(true), relu_v2); | |||||
| AnfAlgo::SetNodeAttr(kAttrDatadumpIsMultiop, MakeValue(true), relu_grad_v2); | |||||
| auto manage = graph->manager(); | auto manage = graph->manager(); | ||||
| MS_EXCEPTION_IF_NULL(manage); | MS_EXCEPTION_IF_NULL(manage); | ||||
| @@ -228,6 +228,7 @@ constexpr auto kAttrLabelSwitchList = "label_switch_list"; | |||||
| constexpr auto kAttrNewAxisMask = "new_axis_mask"; | constexpr auto kAttrNewAxisMask = "new_axis_mask"; | ||||
| constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask"; | constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask"; | ||||
| constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names"; | constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names"; | ||||
| constexpr auto kAttrDatadumpIsMultiop = "_datadump_is_multiop"; | |||||
| constexpr auto kAttrStreamId = "stream_id"; | constexpr auto kAttrStreamId = "stream_id"; | ||||
| constexpr auto kAttrRecordEvent = "record_event"; | constexpr auto kAttrRecordEvent = "record_event"; | ||||
| constexpr auto kAttrWaitEvent = "wait_event"; | constexpr auto kAttrWaitEvent = "wait_event"; | ||||