Browse Source

!8224 optimize internal depend output

From: @kisnwang
Reviewed-by: 
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c442ac0f63
4 changed files with 47 additions and 25 deletions
  1. +0
    -14
      mindspore/ccsrc/backend/session/ascend_session.cc
  2. +0
    -2
      mindspore/ccsrc/backend/session/ascend_session.h
  3. +46
    -8
      mindspore/ccsrc/backend/session/session_basic.cc
  4. +1
    -1
      mindspore/ccsrc/backend/session/session_basic.h

+ 0
- 14
mindspore/ccsrc/backend/session/ascend_session.cc View File

@@ -712,20 +712,6 @@ void AscendSession::SetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
} }


GraphId AscendSession::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
for (const auto &graph_item : graphs_) {
auto graph = graph_item.second;
MS_EXCEPTION_IF_NULL(graph);
// if front_anf is a parameter,the backend parameter may have two
if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
return graph_item.first;
}
}
MS_EXCEPTION_IF_NULL(front_anf);
MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
return kInvalidGraphId;
}

void AscendSession::MergeGraphExecOrder() { void AscendSession::MergeGraphExecOrder() {
MS_LOG(INFO) << "Start!"; MS_LOG(INFO) << "Start!";
// merge graph order // merge graph order


+ 0
- 2
mindspore/ccsrc/backend/session/ascend_session.h View File

@@ -62,8 +62,6 @@ class AscendSession : public SessionBasic {
} }
} }


// get graph id in child graphs by ME front anf node pointer
GraphId GetGraphIdByNode(const AnfNodePtr &front_anf) const override;
// get graph id of final graph // get graph id of final graph
GraphId GetFinalRunGraph() const override { return final_graph_id_; } GraphId GetFinalRunGraph() const override { return final_graph_id_; }




+ 46
- 8
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -341,6 +341,20 @@ void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id
executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id); executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
} }


GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
for (const auto &graph_item : graphs_) {
auto graph = graph_item.second;
MS_EXCEPTION_IF_NULL(graph);
// if front_anf is a parameter,the backend parameter may have two
if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
return graph_item.first;
}
}
MS_EXCEPTION_IF_NULL(front_anf);
MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " is not exist in any graph";
return kInvalidGraphId;
}

KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const { KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
auto it = graphs_.find(graph_id); auto it = graphs_.find(graph_id);
if (it == graphs_.end()) { if (it == graphs_.end()) {
@@ -1216,16 +1230,41 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
return result; return result;
} }


void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
MS_EXCEPTION_IF_NULL(front_node);
if (!front_node->isa<CNode>()) {
return nullptr;
}
if (AnfAlgo::IsRealKernel(front_node)) {
return front_node;
}
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
return front_node;
}
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
auto cnode = front_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
if (inputs.size() > 2) {
return GetSupportedInternalNode(inputs[1]);
}
}
return nullptr;
}

void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
const FuncGraphManagerPtr &front_func_graph_manager, const FuncGraphManagerPtr &front_func_graph_manager,
const std::shared_ptr<KernelGraph> &backend_graph) { const std::shared_ptr<KernelGraph> &backend_graph) {
// When init parameter from cnode of other graphs, the cnode will not be real kernel except for tuple_getitem.
if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
auto front_node = GetSupportedInternalNode(input_front_node);
if (front_node == nullptr) {
return; return;
} }
auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);

auto backend_real_kernel = backend_real_kernel_pair.first;
if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) {
return;
}
auto front_real_kernel = front_real_kernel_pair.first; auto front_real_kernel = front_real_kernel_pair.first;
std::string kernel_target = GetCNodeTarget(front_real_kernel); std::string kernel_target = GetCNodeTarget(front_real_kernel);
bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel); bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
@@ -1254,10 +1293,9 @@ void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backen
} }
} }
if (internal_output) { if (internal_output) {
MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << " To "
<< backend_real_kernel_pair.first->DebugString() << ", unique_target: " << unique_target;
backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second,
unique_target);
MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString()
<< ", unique_target: " << unique_target;
backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
} }
} }
} // namespace } // namespace


+ 1
- 1
mindspore/ccsrc/backend/session/session_basic.h View File

@@ -90,7 +90,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph); CNodePtr CreateNewCNode(CNodePtr cnode, KernelGraph *graph);


// get graph id in child graphs by ME front anf node pointer // get graph id in child graphs by ME front anf node pointer
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const { return kInvalidGraphId; }
virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const;
virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; }
void CheckPSModeConsistence(const KernelGraphPtr &Kernel_graph); void CheckPSModeConsistence(const KernelGraphPtr &Kernel_graph);
void AssignParamKey(const KernelGraphPtr &kernel_graph); void AssignParamKey(const KernelGraphPtr &kernel_graph);


Loading…
Cancel
Save