|
|
|
@@ -18,6 +18,7 @@ |
|
|
|
#include <memory> |
|
|
|
#include <vector> |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
#include "backend/optimizer/ascend/ascend_helper.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
@@ -30,12 +31,12 @@ const BaseRef InsertTransOp::DefinePattern() const { |
|
|
|
return VectorRef({V, Xs}); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) { |
|
|
|
bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { |
|
|
|
auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}); |
|
|
|
auto iter = std::find(outputs.begin(), outputs.end(), node); |
|
|
|
if (iter != outputs.end()) { |
|
|
|
if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -55,7 +56,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode && |
|
|
|
!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK)) { |
|
|
|
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { |
|
|
|
if (IsGraphOutput(node, func_graph)) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
} |
|
|
|
|