Browse Source

!22978 fix insert cast in cpu pynative mode

Merge pull request !22978 from baihuawei/fix_pynative_insert_cast
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
cbce050af6
1 changed files with 12 additions and 7 deletions
  1. +12
    -7
      mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc

+ 12
- 7
mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc View File

@@ -26,6 +26,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "utils/utils.h"
#include "utils/ms_context.h"
#include "backend/kernel_compiler/common_utils.h"
namespace mindspore {
@@ -129,13 +130,17 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
InsertCast(func_graph, cnode);
}
}
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
auto func_output = func_graph->output();
for (auto node : outputs) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
auto cnode = node->cast<CNodePtr>();
InsertCastForGraphOutput(func_graph, cnode, func_output);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
auto func_output = func_graph->output();
for (auto node : outputs) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
auto cnode = node->cast<CNodePtr>();
InsertCastForGraphOutput(func_graph, cnode, func_output);
}
}
}
return true;


Loading…
Cancel
Save