|
|
|
@@ -20,6 +20,7 @@ |
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
#include <utility>
|
|
|
|
#include "backend/optimizer/common/helper.h"
|
|
|
|
#include "backend/kernel_compiler/kernel_build_info.h"
|
|
|
|
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
@@ -89,6 +90,31 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { |
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
|
|
|
for (size_t i = 0; i < output_num; i++) {
|
|
|
|
auto infer_type = AnfAlgo::GetOutputInferDataType(cnode, i);
|
|
|
|
auto device_type = AnfAlgo::GetOutputDeviceDataType(cnode, i);
|
|
|
|
const std::string dev_fmt = AnfAlgo::GetOutputFormat(cnode, i);
|
|
|
|
if (infer_type != device_type) {
|
|
|
|
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
|
|
|
|
for (size_t j = 0; j < used_node_list->size(); j++) {
|
|
|
|
auto used_node = used_node_list->at(j).first;
|
|
|
|
auto used_node_index = static_cast<size_t>(used_node_list->at(j).second - 1);
|
|
|
|
auto cur_input = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(used_node), used_node_index);
|
|
|
|
const std::vector<size_t> origin_shape =
|
|
|
|
AnfAlgo::GetPrevNodeOutputInferShape(utils::cast<CNodePtr>(used_node), i);
|
|
|
|
auto cast =
|
|
|
|
AddCastOpNodeToGraph(func_graph, cur_input, dev_fmt, device_type, infer_type, origin_shape, infer_type);
|
|
|
|
MS_EXCEPTION_IF_NULL(cast);
|
|
|
|
cast->set_scope(used_node->scope());
|
|
|
|
utils::cast<CNodePtr>(used_node)->set_input(used_node_index + 1, cast);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
@@ -100,6 +126,14 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) { |
|
|
|
InsertCast(func_graph, cnode);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
AnfNodePtrList outputs;
|
|
|
|
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
|
|
|
|
for (auto node : outputs) {
|
|
|
|
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
InsertCastForGraphOutput(func_graph, cnode);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
} // namespace opt
|
|
|
|
|