Browse Source

!10522 Add check LabelSwitch op in if_by_if case

From: @liangzelang
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
e52173f631
3 changed files with 23 additions and 4 deletions
  1. +19
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
  2. +2
    -0
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
  3. +2
    -4
      mindspore/ccsrc/backend/session/ascend_session.cc

+ 19
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -1006,6 +1006,25 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
return node->has_default(); return node->has_default();
} }


bool AnfRuntimeAlgorithm::IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) == kLabelGotoOpName &&
(AnfAlgo::GetNodeAttr<uint32_t>(cnode, kAttrLabelIndex) == label_index)) {
return true;
} else if (AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cnode, kAttrLabelSwitchList);
if (std::find(label_list.begin(), label_list.end(), label_index) != label_list.end()) {
return true;
}
}
return false;
}

void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info()); auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());


+ 2
- 0
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h View File

@@ -188,6 +188,8 @@ class AnfRuntimeAlgorithm {
static bool IsNodeInGraphKernel(const AnfNodePtr &node); static bool IsNodeInGraphKernel(const AnfNodePtr &node);
// check parameter is weight or data // check parameter is weight or data
static bool IsParameterWeight(const ParameterPtr &node); static bool IsParameterWeight(const ParameterPtr &node);
// checkout whether the anf node is include the label_index.
static bool IsLabelIndexInNode(const AnfNodePtr &node, size_t label_index);
// set stream id of kernel,which will be set in stream assign and be used in stream generate // set stream id of kernel,which will be set in stream assign and be used in stream generate
static void SetStreamId(uint32_t stream_id, AnfNode *node); static void SetStreamId(uint32_t stream_id, AnfNode *node);
// get stream id // get stream id


+ 2
- 4
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -1233,7 +1233,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) {
MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs"; MS_LOG(INFO) << "graph: " << graph->graph_id() << " has been called by more than two graphs";
int32_t index = 0; int32_t index = 0;
std::vector<KernelGraphPtr> child_graphs; std::vector<KernelGraphPtr> child_graphs;
auto start_label = graph->get_start_label();
auto start_label_id = AnfAlgo::GetNodeAttr<uint32_t>(graph->get_start_label(), kAttrLabelIndex);
auto end_node = graph->get_end_goto(); auto end_node = graph->get_end_goto();
ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0); ParameterPtr post_label_param = graph->AddExtraParamAndTensor("label_param", 0);
std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)), std::vector<AnfNodePtr> new_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(kLabelSwitchOpName)),
@@ -1242,9 +1242,7 @@ void AscendSession::MultiCallGraphOptimize(NotNull<KernelGraphPtr> root_graph) {
auto kg = graphs_[graph_id]; auto kg = graphs_[graph_id];
auto nodes = kg->execution_order(); auto nodes = kg->execution_order();
for (uint32_t i = 0; i < nodes.size(); i++) { for (uint32_t i = 0; i < nodes.size(); i++) {
if (AnfAlgo::GetCNodeName(nodes[i]) == kLabelGotoOpName &&
(AnfAlgo::GetNodeAttr<uint32_t>(nodes[i], kAttrLabelIndex) ==
AnfAlgo::GetNodeAttr<uint32_t>(start_label, kAttrLabelIndex))) {
if (AnfAlgo::IsLabelIndexInNode(nodes[i], start_label_id)) {
if (i < (nodes.size() - 1)) { if (i < (nodes.size() - 1)) {
new_inputs.push_back(nodes[i + 1]); new_inputs.push_back(nodes[i + 1]);
} else { } else {


Loading…
Cancel
Save