Browse Source

!6541 fix bug of inset transdata in pynative mode

Merge pull request !6541 from lianliguang/fix-bug-of-insert-transdata-of-pynative
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
564b99e549
3 changed files with 13 additions and 4 deletions
  1. +5
    -4
      mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
  2. +6
    -0
      mindspore/ccsrc/backend/optimizer/common/helper.cc
  3. +2
    -0
      mindspore/ccsrc/backend/optimizer/common/helper.h

+ 5
- 4
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc View File

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_context.h"
@@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs});
}

bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) {
if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) {
return true;
}

return false;
}

@@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
if (IsGraphOutput(node, func_graph)) {
return new_node;
}
}


+ 6
- 0
mindspore/ccsrc/backend/optimizer/common/helper.cc View File

@@ -485,6 +485,12 @@ void RemoveNopNode(session::KernelGraph *const graph) {
}
}

size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
auto out_list = GetRealNodeUsedList(graph, node);
MS_EXCEPTION_IF_NULL(out_list);
return out_list->size();
}

std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) {
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();


+ 2
- 0
mindspore/ccsrc/backend/optimizer/common/helper.h View File

@@ -172,6 +172,8 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node);

size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node);

std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node,
size_t output_index);


Loading…
Cancel
Save