|
|
|
@@ -26,6 +26,7 @@ |
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
#include "backend/session/kernel_graph.h"
|
|
|
|
#include "utils/utils.h"
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
#include "backend/kernel_compiler/common_utils.h"
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
@@ -129,13 +130,17 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) { |
|
|
|
InsertCast(func_graph, cnode);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
AnfNodePtrList outputs;
|
|
|
|
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
|
|
|
|
auto func_output = func_graph->output();
|
|
|
|
for (auto node : outputs) {
|
|
|
|
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
InsertCastForGraphOutput(func_graph, cnode, func_output);
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
|
|
|
|
AnfNodePtrList outputs;
|
|
|
|
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
|
|
|
|
auto func_output = func_graph->output();
|
|
|
|
for (auto node : outputs) {
|
|
|
|
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
InsertCastForGraphOutput(func_graph, cnode, func_output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
|