Browse Source

!6541 fix bug of inset transdata in pynative mode

Merge pull request !6541 from lianliguang/fix-bug-of-insert-transdata-of-pynative
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
564b99e549
3 changed files with 13 additions and 4 deletions
  1. +5
    -4
      mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
  2. +6
    -0
      mindspore/ccsrc/backend/optimizer/common/helper.cc
  3. +2
    -0
      mindspore/ccsrc/backend/optimizer/common/helper.h

+ 5
- 4
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc View File

@@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "utils/utils.h" #include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/ascend/ascend_helper.h" #include "backend/optimizer/ascend/ascend_helper.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
@@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const {
return VectorRef({V, Xs}); return VectorRef({V, Xs});
} }


bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) {
bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem});
auto iter = std::find(outputs.begin(), outputs.end(), node); auto iter = std::find(outputs.begin(), outputs.end(), node);
if (iter != outputs.end()) {
if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) {
return true; return true;
} }

return false; return false;
} }


@@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) { !ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) {
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) {
if (IsGraphOutput(node, func_graph)) {
return new_node; return new_node;
} }
} }


+ 6
- 0
mindspore/ccsrc/backend/optimizer/common/helper.cc View File

@@ -485,6 +485,12 @@ void RemoveNopNode(session::KernelGraph *const graph) {
} }
} }


size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node) {
auto out_list = GetRealNodeUsedList(graph, node);
MS_EXCEPTION_IF_NULL(out_list);
return out_list->size();
}

std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node) { const AnfNodePtr &node) {
auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>(); auto output_node_list = std::make_shared<std::vector<std::pair<AnfNodePtr, int>>>();


+ 2
- 0
mindspore/ccsrc/backend/optimizer/common/helper.h View File

@@ -172,6 +172,8 @@ bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(const FuncGraphPtr &graph,
const AnfNodePtr &node); const AnfNodePtr &node);


size_t GetRealNodeNum(const FuncGraphPtr &graph, const AnfNodePtr &node);

std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph, std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOutputIdx(const FuncGraphPtr &graph,
const AnfNodePtr &node, const AnfNodePtr &node,
size_t output_index); size_t output_index);


Loading…
Cancel
Save