|
|
|
@@ -192,8 +192,8 @@ void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr & |
|
|
|
if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) { |
|
|
|
AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get()); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType() |
|
|
|
<< ", device tensor type:" << device_tensor->DeviceType(); |
|
|
|
MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType() |
|
|
|
<< ", device tensor type:" << device_tensor->DeviceType(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -845,6 +845,7 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp |
|
|
|
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) { |
|
|
|
MS_EXCEPTION_IF_NULL(actor_set); |
|
|
|
std::vector<KernelActor *> auto_monad_actors; |
|
|
|
std::vector<CNodePtr> communication_nodes; |
|
|
|
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = { |
|
|
|
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad}; |
|
|
|
|
|
|
|
@@ -854,6 +855,9 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto execution_order = graph->execution_order(); |
|
|
|
for (auto &kernel : execution_order) { |
|
|
|
if (AnfAlgo::IsCommunicationOp(kernel)) { |
|
|
|
communication_nodes.emplace_back(kernel); |
|
|
|
} |
|
|
|
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -881,10 +885,11 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co |
|
|
|
} |
|
|
|
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph. |
|
|
|
LinkControlArrowBySendRecvNodes(graph); |
|
|
|
// Link the control arrows by the communication nodes to ensure communication nodes running order. |
|
|
|
LinkControlArrowByCommunicationNode(graph); |
|
|
|
} |
|
|
|
|
|
|
|
// Link the control arrows by the communication nodes to ensure communication nodes running order. |
|
|
|
LinkControlArrowByCommunicationNode(communication_nodes); |
|
|
|
|
|
|
|
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) { |
|
|
|
// Link the arrow by control node. |
|
|
|
LinkArrowByControlNode(graph_compiler_info, actor_set); |
|
|
|
@@ -1676,8 +1681,10 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph |
|
|
|
// inputs of to_allreduce_actor --> from_send_actor |
|
|
|
for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) { |
|
|
|
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name())); |
|
|
|
input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID()); |
|
|
|
from_send_actor->input_controls_num_++; |
|
|
|
if (input_actor != nullptr) { |
|
|
|
input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID()); |
|
|
|
from_send_actor->input_controls_num_++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// from_send_actor --> from_recv_actor |
|
|
|
@@ -1709,8 +1716,10 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph |
|
|
|
// to_recv_actor --> outputs of from_allreduce_actor |
|
|
|
for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) { |
|
|
|
auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name())); |
|
|
|
to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID()); |
|
|
|
output_actor->input_controls_num_++; |
|
|
|
if (output_actor != nullptr) { |
|
|
|
to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID()); |
|
|
|
output_actor->input_controls_num_++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be |
|
|
|
@@ -1724,22 +1733,26 @@ void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkControlArrowByCommunicationNode(const KernelGraphPtr &graph) { |
|
|
|
std::vector<CNodePtr> communication_nodes; |
|
|
|
auto execution_order = graph->execution_order(); |
|
|
|
for (auto &kernel : execution_order) { |
|
|
|
if (AnfAlgo::IsCommunicationOp(kernel)) { |
|
|
|
communication_nodes.emplace_back(kernel); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes) { |
|
|
|
for (size_t i = 1; i < communication_nodes.size(); ++i) { |
|
|
|
auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope())); |
|
|
|
auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope())); |
|
|
|
MS_EXCEPTION_IF_NULL(from_actor); |
|
|
|
MS_EXCEPTION_IF_NULL(to_actor); |
|
|
|
// Ensure communication node to execute orderly. |
|
|
|
from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); |
|
|
|
to_actor->input_controls_num_++; |
|
|
|
|
|
|
|
// Ensure the input actor of next communication actor is after the previous communication actor to optimize the |
|
|
|
// execution performance in the multi device scenario. |
|
|
|
// Using the multi stream to optimize the performance in the future. |
|
|
|
for (auto &input_aid : to_actor->input_data_arrow_aids_) { |
|
|
|
auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name())); |
|
|
|
if ((input_actor != nullptr) && (from_actor != input_actor)) { |
|
|
|
from_actor->output_control_arrows_.emplace_back(input_actor->GetAID()); |
|
|
|
input_actor->input_controls_num_++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|