Browse Source

!3216 Add op mapping attr for those opt pass worked in LeNet

Merge pull request !3216 from huanghui/add-op-mapping-attrs-for-low-level-opt-pass
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
11145b0987
3 changed files with 12 additions and 0 deletions
  1. +3
    -0
      mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
  2. +8
    -0
      mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc
  3. +1
    -0
      mindspore/ccsrc/utils/utils.h

+ 3
- 0
mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc View File

@@ -212,6 +212,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
MS_EXCEPTION_IF_NULL(kernel_select);
kernel_select->SelectKernel(trans_node);
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), trans_node);
MS_EXCEPTION_IF_NULL(trans_node);
trans_node->set_scope(input->scope());
return trans_node;
@@ -250,6 +251,7 @@ AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get());
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, cast.get());
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(true), cast);
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), cast);
return cast;
}

@@ -354,6 +356,7 @@ AnfNodePtr CreateMemcpyAsyncOp(const FuncGraphPtr &graph, const AnfNodePtr &node
MS_EXCEPTION_IF_NULL(new_node);
new_node->set_abstract(node->abstract());
new_node->set_scope(node->scope());
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, MakeValue<std::vector<std::string>>({}), new_node);
return new_node;
}
} // namespace opt


+ 8
- 0
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc View File

@@ -15,6 +15,7 @@
*/
#include "backend/optimizer/ascend/ir_fusion/derelu_fusion.h"
#include <memory>
#include <string>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
@@ -111,6 +112,13 @@ const AnfNodePtr DereluFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
CreateMultipleOutputsOfAnfNode(graph, relu_v2, kReluV2OutputNum, &relu_v2_node_outputs);

auto relu_grad_v2 = CreateReluGradV2(graph, relu_grad, relu_v2_node_outputs[1]);
// Add attr mapping from original nodes to fusion nodes
auto original_names =
MakeValue<std::vector<std::string>>({relu->fullname_with_scope(), relu_grad->fullname_with_scope()});
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, original_names, relu_v2);
AnfAlgo::SetNodeAttr(kAttrDatadumpOriginalNames, original_names, relu_grad_v2);
AnfAlgo::SetNodeAttr(kAttrDatadumpIsMultiop, MakeValue(true), relu_v2);
AnfAlgo::SetNodeAttr(kAttrDatadumpIsMultiop, MakeValue(true), relu_grad_v2);

auto manage = graph->manager();
MS_EXCEPTION_IF_NULL(manage);


+ 1
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -228,6 +228,7 @@ constexpr auto kAttrLabelSwitchList = "label_switch_list";
constexpr auto kAttrNewAxisMask = "new_axis_mask";
constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask";
constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names";
constexpr auto kAttrDatadumpIsMultiop = "_datadump_is_multiop";
constexpr auto kAttrStreamId = "stream_id";
constexpr auto kAttrRecordEvent = "record_event";
constexpr auto kAttrWaitEvent = "wait_event";


Loading…
Cancel
Save