|
|
|
@@ -91,7 +91,7 @@ void InsertCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { |
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
|
|
|
void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &func_output) {
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
|
|
|
|
for (size_t i = 0; i < output_num; i++) {
|
|
|
|
@@ -102,6 +102,9 @@ void InsertCastForGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn |
|
|
|
auto used_node_list = GetRealNodeUsedListByOutputIdx(func_graph, cnode, i);
|
|
|
|
for (size_t j = 0; j < used_node_list->size(); j++) {
|
|
|
|
auto used_node = used_node_list->at(j).first;
|
|
|
|
if (used_node != func_output) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
auto used_node_index = static_cast<size_t>(used_node_list->at(j).second - 1);
|
|
|
|
auto cur_input = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(used_node), used_node_index);
|
|
|
|
const std::vector<size_t> origin_shape =
|
|
|
|
@@ -128,10 +131,11 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) { |
|
|
|
}
|
|
|
|
AnfNodePtrList outputs;
|
|
|
|
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
|
|
|
|
auto func_output = func_graph->output();
|
|
|
|
for (auto node : outputs) {
|
|
|
|
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
InsertCastForGraphOutput(func_graph, cnode);
|
|
|
|
InsertCastForGraphOutput(func_graph, cnode, func_output);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
|