|
|
|
@@ -16,11 +16,13 @@ |
|
|
|
|
|
|
|
#include "pre_activate/ascend/format_type/insert_trans_op.h" |
|
|
|
#include <memory> |
|
|
|
#include <vector> |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "pre_activate/ascend/ascend_helper.h" |
|
|
|
#include "session/anf_runtime_algorithm.h" |
|
|
|
#include "device/kernel_info.h" |
|
|
|
#include "kernel/oplib/oplib.h" |
|
|
|
#include "utils/context/ms_context.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
@@ -30,6 +32,15 @@ const BaseRef InsertTransOp::DefinePattern() const { |
|
|
|
return VectorRef({V, Xs}); |
|
|
|
} |
|
|
|
|
|
|
|
bool IsGraphOutput(const AnfNodePtr &node, const std::vector<AnfNodePtr> &outputs) { |
|
|
|
auto iter = std::find(outputs.begin(), outputs.end(), node); |
|
|
|
if (iter != outputs.end()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const EquivPtr &) const { |
|
|
|
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { |
|
|
|
@@ -38,6 +49,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An |
|
|
|
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); |
|
|
|
MS_LOG(DEBUG) << "====process op: " << node->DebugString(); |
|
|
|
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_); |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
if (ms_context->execution_mode() == kPynativeMode) { |
|
|
|
if (IsGraphOutput(node, AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}))) { |
|
|
|
return new_node; |
|
|
|
} |
|
|
|
} |
|
|
|
return InsertTransOpForOutput(func_graph, new_node, kernel_select_); |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
|