Browse Source

Actor runtime: enforce the running order of communication nodes

tags/v1.3.0
limingqi107 4 years ago
parent
commit
79f1ab6724
3 changed files with 44 additions and 9 deletions
  1. +21
    -9
      mindspore/ccsrc/frontend/parallel/group_manager.cc
  2. +21
    -0
      mindspore/ccsrc/runtime/framework/graph_scheduler.cc
  3. +2
    -0
      mindspore/ccsrc/runtime/framework/graph_scheduler.h

+ 21
- 9
mindspore/ccsrc/frontend/parallel/group_manager.cc View File

@@ -20,6 +20,7 @@
#include <utility>
#if !defined(NO_DLIB) || defined(ENABLE_GPU)
#include "backend/session/executor_manager.h"
#include "runtime/framework/actor/actor_common.h"
#else
#include "frontend/parallel/parallel_stub/executor_manager_stub.h"
#endif
@@ -73,18 +74,24 @@ GroupManager::GroupManager() { groups_.clear(); }
#if !defined(NO_DLIB) || defined(ENABLE_GPU)
// Create the communication group `group_name` containing `ranks`.
// When IsMindRTUsed() returns true, the group is created through CommManager::CreateGroupSync; otherwise the
// executor obtained from ExecutorManager for (device_name, device_id) creates it. A null executor raises via
// MS_EXCEPTION_IF_NULL. Returns the result reported by whichever creation call was used.
bool GroupManager::CreateGroupByExecutor(const std::string &device_name, const std::string &group_name,
                                         const std::vector<uint32_t> ranks, int device_id) {
  // The runtime check must come first: an unconditional executor-based return ahead of it would make the
  // MindRT path unreachable.
  if (IsMindRTUsed()) {
    return CommManager::GetInstance().CreateGroupSync(group_name, ranks);
  }
  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
  MS_EXCEPTION_IF_NULL(executor);
  return executor->CreateCommGroup(group_name, ranks);
}

// Destroy the communication group `group_name`.
// When IsMindRTUsed() returns true, the group is destroyed through CommManager::DestroyGroup; otherwise the
// executor obtained from ExecutorManager for (device_name, device_id) destroys it. A null executor raises via
// MS_EXCEPTION_IF_NULL. Returns the result reported by whichever destroy call was used.
bool GroupManager::DestroyGroupByExecutor(const std::string &device_name, const std::string &group_name,
                                          int device_id) {
  // The runtime check must come first: an unconditional executor-based return ahead of it would make the
  // MindRT path unreachable.
  if (IsMindRTUsed()) {
    return CommManager::GetInstance().DestroyGroup(group_name);
  }
  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
  MS_EXCEPTION_IF_NULL(executor);
  return executor->DestroyCommGroup(group_name);
}

Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
@@ -96,7 +103,12 @@ Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
MS_EXCEPTION_IF_NULL(executor);
for (auto &group : group_info) {
bool ret = executor->CreateCommGroup(group.first, group.second);
bool ret = true;
if (IsMindRTUsed()) {
ret = CommManager::GetInstance().CreateGroupSync(group.first, group.second);
} else {
ret = executor->CreateCommGroup(group.first, group.second);
}
if (!ret) {
MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
return FAILED;


+ 21
- 0
mindspore/ccsrc/runtime/framework/graph_scheduler.cc View File

@@ -688,6 +688,8 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co
}
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
LinkControlArrowBySendRecvNodes(graph);
// Link control arrows between the communication nodes to enforce their running order.
LinkControlArrowByCommunicationNode(graph);
}

// Link the arrow by control node.
@@ -1436,6 +1438,25 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph
}
}

// Link a control arrow between every pair of consecutive communication kernels of `graph`.
// The communication kernels are taken in the order they appear in graph->execution_order(). For each adjacent
// pair, the AID of the later kernel's actor is appended to the earlier actor's output_control_arrows_ and the
// later actor's input_controls_num_ is incremented. Graphs with fewer than two communication kernels are left
// untouched. Raises via MS_EXCEPTION_IF_NULL when `graph` is null or an actor cannot be fetched as a KernelActor.
void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  // Collect the communication kernels, preserving the execution order. Bind by const reference so the whole
  // execution-order vector is not copied just to be iterated.
  std::vector<CNodePtr> communication_nodes;
  const auto &execution_order = graph->execution_order();
  for (const auto &kernel : execution_order) {
    if (AnfAlgo::IsCommunicationOp(kernel)) {
      communication_nodes.emplace_back(kernel);
    }
  }

  // Start at index 1 so that every iteration links node (i - 1) -> node i.
  for (size_t i = 1; i < communication_nodes.size(); ++i) {
    auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
    auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
    MS_EXCEPTION_IF_NULL(from_actor);
    MS_EXCEPTION_IF_NULL(to_actor);
    from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
    to_actor->input_controls_num_++;
  }
}

void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(actor_set);


+ 2
- 0
mindspore/ccsrc/runtime/framework/graph_scheduler.h View File

@@ -217,6 +217,8 @@ class GraphScheduler {
void LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node);
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
void LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph);
// Link control arrows between the communication nodes in the kernel graph to enforce their running order.
void LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph);
void LinkOutputResultArrowForOutputActor(OutputActor *to_actor, const GraphCompilerInfo &graph_compiler_info);
void LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors);



Loading…
Cancel
Save