|
|
|
@@ -844,7 +844,7 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def |
|
|
|
const auto &cnode = return_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
const auto &inputs = cnode->inputs(); |
|
|
|
const auto output_nodes = FetchAllOutputWithIndex(inputs[kReturnInputPos]); |
|
|
|
const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]); |
|
|
|
std::vector<const DeviceContext *> return_device_contexts; |
|
|
|
|
|
|
|
for (const auto &output_node : output_nodes) { |
|
|
|
@@ -909,19 +909,13 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def |
|
|
|
|
|
|
|
void ControlNodeParser::FetchFrontNodeToKernelGraph(const std::vector<KernelGraphPtr> &graphs) { |
|
|
|
for (const auto &graph : graphs) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
if (graph->execution_order().empty()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &kernel : graph->execution_order()) { |
|
|
|
auto front_node = graph->GetFrontAnfByBackendAnf(kernel); |
|
|
|
if (front_node != nullptr) { |
|
|
|
front_node_to_kernel_graph_[front_node] = graph; |
|
|
|
} |
|
|
|
} |
|
|
|
const auto &graph_outputs = graph->graph_output_map(); |
|
|
|
for (const auto &backend_to_front : graph_outputs) { |
|
|
|
front_node_to_kernel_graph_[backend_to_front.second.first] = graph; |
|
|
|
const auto &front_to_backend_nodes = graph->front_backend_anf_map(); |
|
|
|
for (const auto &front_to_backend_node : front_to_backend_nodes) { |
|
|
|
front_node_to_kernel_graph_[front_to_backend_node.first] = graph; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|