Browse Source

insert cast

tags/v1.4.0
baihuawei 4 years ago
parent
commit
20b3942ca6
1 changed files with 34 additions and 0 deletions
  1. +34
    -0
      mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc

+ 34
- 0
mindspore/ccsrc/backend/optimizer/cpu/insert_cast_cpu.cc View File

@@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include <utility>
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "backend/session/anf_runtime_algorithm.h"
@@ -89,6 +90,31 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
}
}
}
// Inserts a Cast node behind every output of `cnode` whose device data type differs from its
// inferred data type, and rewires each consumer of that output to read from the Cast instead.
//
// func_graph: graph that is searched for consumers of `cnode` and that receives the new Cast nodes.
// cnode:      node whose outputs are inspected; raises via MS_EXCEPTION_IF_NULL when null.
//
// NOTE(review): the Cast is built with (device_type, infer_type) as its type arguments, so it
// presumably converts from the device type back to the inferred type -- confirm against
// AddCastOpNodeToGraph.
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  for (size_t i = 0; i < output_num; i++) {
    auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, i);
    auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, i);
    const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, i);
    if (infer_type == device_type) {
      // Types already agree for this output; no Cast is required.
      continue;
    }
    auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
    MS_EXCEPTION_IF_NULL(used_node_list);
    for (const auto &used_node_info : *used_node_list) {
      const auto &used_node = used_node_info.first;
      MS_EXCEPTION_IF_NULL(used_node);
      auto used_cnode = utils::cast<CNodePtr>(used_node);
      MS_EXCEPTION_IF_NULL(used_cnode);
      // The recorded position counts the CNode's leading input (set_input below uses index + 1),
      // while the AnfAlgo input helpers are called with the position minus one.
      auto used_node_index = static_cast<size_t>(used_node_info.second - 1);
      auto cur_input = AnfAlgo::GetInputNode(used_cnode, used_node_index);
      // The shape must be looked up through the consumer's input slot (`used_node_index`), not
      // through `i`, which is the output index on `cnode` and need not match the slot.
      const std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(used_cnode, used_node_index);
      auto cast =
        AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, device_type, infer_type, origin_shape, infer_type);
      MS_EXCEPTION_IF_NULL(cast);
      cast->set_scope(used_node->scope());
      used_cnode->set_input(used_node_index + 1, cast);
    }
  }
}
} // namespace
bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
@@ -100,6 +126,14 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
InsertCast(func_graph, cnode);
}
}
AnfNodePtrList outputs;
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
for (auto node : outputs) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
auto cnode = node->cast<CNodePtr>();
InsertCastForGraphOutput(func_graph, cnode);
}
}
return true;
}
} // namespace opt


Loading…
Cancel
Save