|
|
|
@@ -341,6 +341,20 @@ void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id |
|
|
|
executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id); |
|
|
|
} |
|
|
|
|
|
|
|
GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const { |
|
|
|
for (const auto &graph_item : graphs_) { |
|
|
|
auto graph = graph_item.second; |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
// if front_anf is a parameter,the backend parameter may have two |
|
|
|
if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) { |
|
|
|
return graph_item.first; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(front_anf); |
|
|
|
MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph"; |
|
|
|
return kInvalidGraphId; |
|
|
|
} |
|
|
|
|
|
|
|
KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { |
|
|
|
auto it = graphs_.find(graph_id); |
|
|
|
if (it == graphs_.end()) { |
|
|
|
@@ -1216,16 +1230,41 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, |
|
|
|
AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(front_node); |
|
|
|
if (!front_node->isa<CNode>()) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (AnfAlgo::IsRealKernel(front_node)) { |
|
|
|
return front_node; |
|
|
|
} |
|
|
|
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { |
|
|
|
return front_node; |
|
|
|
} |
|
|
|
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) { |
|
|
|
auto cnode = front_node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
auto &inputs = cnode->inputs(); |
|
|
|
if (inputs.size() > 2) { |
|
|
|
return GetSupportedInternalNode(inputs[1]); |
|
|
|
} |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node, |
|
|
|
const FuncGraphManagerPtr &front_func_graph_manager, |
|
|
|
const std::shared_ptr<KernelGraph> &backend_graph) { |
|
|
|
// When init parameter from cnode of other graphs, the cnode will not be real kernel except for tuple_getitem. |
|
|
|
if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { |
|
|
|
auto front_node = GetSupportedInternalNode(input_front_node); |
|
|
|
if (front_node == nullptr) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); |
|
|
|
auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); |
|
|
|
|
|
|
|
auto backend_real_kernel = backend_real_kernel_pair.first; |
|
|
|
if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto front_real_kernel = front_real_kernel_pair.first; |
|
|
|
std::string kernel_target = GetCNodeTarget(front_real_kernel); |
|
|
|
bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel); |
|
|
|
@@ -1254,10 +1293,9 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen |
|
|
|
} |
|
|
|
} |
|
|
|
if (internal_output) { |
|
|
|
MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << " To " |
|
|
|
<< backend_real_kernel_pair.first->DebugString() << ", unique_target: " << unique_target; |
|
|
|
backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second, |
|
|
|
unique_target); |
|
|
|
MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString() |
|
|
|
<< ", unique_target: " << unique_target; |
|
|
|
backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|