|
|
|
@@ -194,7 +194,7 @@ void GraphScheduler::Clear() { |
|
|
|
actor_name_to_actor_.clear(); |
|
|
|
} |
|
|
|
|
|
|
|
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, KernelActor *const, const KernelWithIndex &, |
|
|
|
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &, |
|
|
|
const KernelWithIndex &, const KernelGraphPtr &); |
|
|
|
static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc; |
|
|
|
|
|
|
|
@@ -364,6 +364,7 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) |
|
|
|
auto host_queue = std::make_shared<HostTensorQueue>(); |
|
|
|
actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue); |
|
|
|
actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info); |
|
|
|
actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info); |
|
|
|
actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info); |
|
|
|
actor_set->output_actor_ = BuildOutputActor(graph_compiler_info); |
|
|
|
actor_set->data_prepare_actor_ = |
|
|
|
@@ -386,41 +387,27 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp |
|
|
|
auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index); |
|
|
|
if (origin_output_with_index.first == nullptr) { |
|
|
|
MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope() |
|
|
|
<< " with index: " << output_with_index.second << " has no actor."; |
|
|
|
<< " with index: " << output_with_index.second << " has no front node."; |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto actor_output_index = output_with_index.second; |
|
|
|
OpActor<DeviceTensor> *actor = nullptr; |
|
|
|
if (IsKernelActor(output_kernel, graph_compiler_info.strategy_)) { |
|
|
|
actor = FetchActor(output_kernel->fullname_with_scope()); |
|
|
|
} else if (IsDeviceQueueDSActor(output_kernel, graph_compiler_info.strategy_)) { |
|
|
|
std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id()); |
|
|
|
actor = FetchActor(actor_name); |
|
|
|
} else if (IsHostQueueDSActor(output_kernel, graph, graph_compiler_info.origin_parameters_order_, |
|
|
|
graph_compiler_info.strategy_)) { |
|
|
|
actor = FetchActor(graph_compiler_info.name_ + "_HostDSActor"); |
|
|
|
const auto &host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor); |
|
|
|
MS_EXCEPTION_IF_NULL(host_ds_actor); |
|
|
|
// Get the position of output kernel in the data source actor. |
|
|
|
actor_output_index = host_ds_actor->FetchNodePosition(output_kernel); |
|
|
|
} else if (IsPersistentDeviceTensor(output_kernel)) { |
|
|
|
auto kernel_type = KernelTransformType::kUnknown; |
|
|
|
std::string kernel_name = ""; |
|
|
|
FetchKernelTransformTypeAndName(output_kernel, graph, graph_compiler_info, &kernel_type, &kernel_name); |
|
|
|
if (kernel_name == "") { |
|
|
|
MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope() |
|
|
|
<< " is device tensor store."; |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "Ignore the internal parameter node:" << output_kernel->DebugString(); |
|
|
|
<< " with index:" << output_with_index.second |
|
|
|
<< " is not actor, and the kernel type is:" << kernel_type; |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(actor); |
|
|
|
auto output_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name)); |
|
|
|
MS_EXCEPTION_IF_NULL(output_actor); |
|
|
|
(void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(output_actor, output_with_index)); |
|
|
|
MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope() |
|
|
|
<< " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name() |
|
|
|
<< " with index:" << actor_output_index |
|
|
|
<< " with index: " << output_with_index.second << " to actor:" << output_actor->GetAID().Name() |
|
|
|
<< ", from front node:" << origin_output_with_index.first->fullname_with_scope() |
|
|
|
<< " with index: " << origin_output_with_index.second; |
|
|
|
(void)graph_output_to_actor_.emplace(origin_output_with_index, |
|
|
|
GraphOutputPair(dynamic_cast<AbstractActor *>(actor), actor_output_index)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -429,44 +416,14 @@ void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_co |
|
|
|
MS_EXCEPTION_IF_NULL(actor_set); |
|
|
|
std::vector<KernelActor *> auto_monad_actors; |
|
|
|
std::vector<CNodePtr> communication_nodes; |
|
|
|
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = { |
|
|
|
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad}; |
|
|
|
|
|
|
|
// Foreach the execution order to link the actors. |
|
|
|
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) { |
|
|
|
const auto &graph = graph_compiler_info.graphs_[index]; |
|
|
|
for (const auto &graph : graph_compiler_info.graphs_) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto execution_order = graph->execution_order(); |
|
|
|
for (auto &kernel : execution_order) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
if (AnfAlgo::IsCommunicationOp(kernel)) { |
|
|
|
(void)communication_nodes.emplace_back(kernel); |
|
|
|
} |
|
|
|
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope())); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_actor); |
|
|
|
|
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) { |
|
|
|
auto input_node = AnfAlgo::GetInputNode(kernel, i); |
|
|
|
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node. |
|
|
|
if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) { |
|
|
|
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph); |
|
|
|
} |
|
|
|
if (HasAbstractMonad(input_node)) { |
|
|
|
(void)auto_monad_actors.emplace_back(kernel_actor); |
|
|
|
continue; // No data arrow for monad input. |
|
|
|
} |
|
|
|
|
|
|
|
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); |
|
|
|
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i); |
|
|
|
// The gather of linking data arrows of kernel by the different from kernel type. |
|
|
|
LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx); |
|
|
|
} |
|
|
|
if (graph->is_sink()) { |
|
|
|
LinkDataArrowInSinkMode(graph, graph_compiler_info); |
|
|
|
} else { |
|
|
|
LinkDataArrowInNonSinkMode(graph, graph_compiler_info, &auto_monad_actors, &communication_nodes); |
|
|
|
} |
|
|
|
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph. |
|
|
|
LinkControlArrowBySendRecvNodes(graph); |
|
|
|
} |
|
|
|
|
|
|
|
// Link the arrow in the control flow scene. |
|
|
|
@@ -523,22 +480,25 @@ std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const Graph |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Build device queue data source actor. |
|
|
|
const auto &execution_order = graph->execution_order(); |
|
|
|
const auto &iter = |
|
|
|
std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) { |
|
|
|
return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_); |
|
|
|
}); |
|
|
|
if (iter != execution_order.end()) { |
|
|
|
auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id()); |
|
|
|
MS_LOG(INFO) << "Create queue data source actor: " << actor_name; |
|
|
|
auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>( |
|
|
|
actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_); |
|
|
|
MS_EXCEPTION_IF_NULL(device_queue_ds_actor); |
|
|
|
InsertActor(device_queue_ds_actor.get()); |
|
|
|
(void)data_source_actors.emplace_back(device_queue_ds_actor); |
|
|
|
device_queue_ds_actor->data_kernel_ = *iter; |
|
|
|
device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info()); |
|
|
|
// The graph sink mode has no device queue data source actor. |
|
|
|
if (!graph->is_sink()) { |
|
|
|
// Build device queue data source actor. |
|
|
|
const auto &execution_order = graph->execution_order(); |
|
|
|
const auto &iter = |
|
|
|
std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) { |
|
|
|
return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_); |
|
|
|
}); |
|
|
|
if (iter != execution_order.end()) { |
|
|
|
auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id()); |
|
|
|
MS_LOG(INFO) << "Create queue data source actor: " << actor_name; |
|
|
|
auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>( |
|
|
|
actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_); |
|
|
|
MS_EXCEPTION_IF_NULL(device_queue_ds_actor); |
|
|
|
InsertActor(device_queue_ds_actor.get()); |
|
|
|
(void)data_source_actors.emplace_back(device_queue_ds_actor); |
|
|
|
device_queue_ds_actor->data_kernel_ = *iter; |
|
|
|
device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -588,8 +548,11 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler |
|
|
|
const auto &graph = graph_compiler_info.graphs_[i]; |
|
|
|
const auto &device_context = graph_compiler_info.device_contexts_[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto execution_order = graph->execution_order(); |
|
|
|
if (graph->is_sink()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto execution_order = graph->execution_order(); |
|
|
|
// Single op graph in step mode, kernel actor executes synchronously. |
|
|
|
bool is_single_op_graph = execution_order.size() == 1; |
|
|
|
GraphExecutionStrategy strategy = graph_compiler_info.strategy_; |
|
|
|
@@ -615,6 +578,27 @@ std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompiler |
|
|
|
return kernel_actors; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<SuperKernelActorPtr> GraphScheduler::BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info) { |
|
|
|
std::vector<SuperKernelActorPtr> super_kernel_actors; |
|
|
|
|
|
|
|
for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) { |
|
|
|
const auto &graph = graph_compiler_info.graphs_[i]; |
|
|
|
const auto &device_context = graph_compiler_info.device_contexts_[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
if (!graph->is_sink()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto actor_name = graph->ToString() + "_SuperKernelActor"; |
|
|
|
auto super_kernel_actor = |
|
|
|
std::make_shared<SuperKernelActor>(actor_name, graph, device_context, memory_manager_aid_, nullptr, nullptr); |
|
|
|
MS_EXCEPTION_IF_NULL(super_kernel_actor); |
|
|
|
InsertActor(super_kernel_actor.get()); |
|
|
|
(void)super_kernel_actors.emplace_back(super_kernel_actor); |
|
|
|
} |
|
|
|
return super_kernel_actors; |
|
|
|
} |
|
|
|
|
|
|
|
LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) { |
|
|
|
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) { |
|
|
|
return nullptr; |
|
|
|
@@ -658,7 +642,6 @@ DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInf |
|
|
|
if (iter != data_source_actors.end()) { |
|
|
|
host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter); |
|
|
|
} |
|
|
|
|
|
|
|
auto actor_name = graph_compiler_info.name_ + "_DataPrepareActor"; |
|
|
|
auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_, |
|
|
|
&graph_compiler_info, host_queue_ds_actor, host_queue); |
|
|
|
@@ -670,12 +653,15 @@ DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInf |
|
|
|
for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) { |
|
|
|
const auto &graph = graph_compiler_info.graphs_[index]; |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
if (graph->is_sink()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto &execution_order = graph->execution_order(); |
|
|
|
for (auto &kernel : execution_order) { |
|
|
|
if (!AnfAlgo::IsCommunicationOp(kernel)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]); |
|
|
|
auto value = std::make_pair(false, false); |
|
|
|
if (AnfAlgo::GetInputTensorNum(kernel) > 1) { |
|
|
|
@@ -695,10 +681,17 @@ DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInf |
|
|
|
return data_prepare_actor; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set, |
|
|
|
GraphExecutionStrategy strategy) { |
|
|
|
std::vector<AbstractActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set, |
|
|
|
GraphExecutionStrategy strategy) { |
|
|
|
MS_EXCEPTION_IF_NULL(actor_set); |
|
|
|
std::vector<KernelActorPtr> no_input_kernel_actors; |
|
|
|
std::vector<AbstractActorPtr> no_input_kernel_actors; |
|
|
|
|
|
|
|
for (auto &super_kernel_actor : actor_set->super_kernel_actors_) { |
|
|
|
MS_EXCEPTION_IF_NULL(super_kernel_actor); |
|
|
|
if ((super_kernel_actor->input_datas_num_ == 0) && (super_kernel_actor->input_controls_num_ == 0)) { |
|
|
|
(void)no_input_kernel_actors.emplace_back(super_kernel_actor); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto &kernel_actor : actor_set->kernel_actors_) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_actor); |
|
|
|
@@ -730,6 +723,78 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorS |
|
|
|
return no_input_kernel_actors; |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph, |
|
|
|
const GraphCompilerInfo &graph_compiler_info) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
auto to_actor_name = graph->ToString() + "_SuperKernelActor"; |
|
|
|
auto to_actor = dynamic_cast<SuperKernelActor *>(FetchActor(to_actor_name)); |
|
|
|
MS_EXCEPTION_IF_NULL(to_actor); |
|
|
|
|
|
|
|
for (const auto &input_node : graph->input_nodes()) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
auto kernel_type = KernelTransformType::kUnknown; |
|
|
|
std::string kernel_name = ""; |
|
|
|
FetchKernelTransformTypeAndName(input_node, graph, graph_compiler_info, &kernel_type, &kernel_name); |
|
|
|
|
|
|
|
KernelWithIndex from_kernel_with_output_idx = std::make_pair(input_node, 0); |
|
|
|
KernelWithIndex to_kernel_with_input_idx = std::make_pair(nullptr, 0); |
|
|
|
AbstractActor *from_actor = nullptr; |
|
|
|
if (kernel_type == KernelTransformType::kHostDataSourceActor) { |
|
|
|
from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name)); |
|
|
|
} |
|
|
|
|
|
|
|
if ((from_actor != nullptr) && (kKernelTypeToLinkFunc.count(kernel_type) > 0)) { |
|
|
|
(this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx, |
|
|
|
to_kernel_with_input_idx, graph); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, |
|
|
|
const GraphCompilerInfo &graph_compiler_info, |
|
|
|
std::vector<KernelActor *> *const auto_monad_actors, |
|
|
|
std::vector<CNodePtr> *const communication_nodes) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(auto_monad_actors); |
|
|
|
MS_EXCEPTION_IF_NULL(communication_nodes); |
|
|
|
|
|
|
|
const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = { |
|
|
|
prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad}; |
|
|
|
auto &execution_order = graph->execution_order(); |
|
|
|
// Foreach the execution order to link the actors. |
|
|
|
for (const auto &kernel : execution_order) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel); |
|
|
|
if (AnfAlgo::IsCommunicationOp(kernel)) { |
|
|
|
(void)communication_nodes->emplace_back(kernel); |
|
|
|
} |
|
|
|
if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope())); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_actor); |
|
|
|
|
|
|
|
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) { |
|
|
|
auto input_node = AnfAlgo::GetInputNode(kernel, i); |
|
|
|
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node. |
|
|
|
if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) { |
|
|
|
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph); |
|
|
|
} |
|
|
|
if (HasAbstractMonad(input_node)) { |
|
|
|
(void)auto_monad_actors->emplace_back(kernel_actor); |
|
|
|
continue; // No data arrow for monad input. |
|
|
|
} |
|
|
|
|
|
|
|
KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false); |
|
|
|
KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i); |
|
|
|
// The gather of linking data arrows of kernel by the different from kernel type. |
|
|
|
LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph. |
|
|
|
LinkControlArrowBySendRecvNodes(graph); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info, |
|
|
|
const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx) { |
|
|
|
@@ -784,7 +849,7 @@ void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompi |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, KernelActor *const to_actor, |
|
|
|
void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, AbstractActor *const to_actor, |
|
|
|
const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx, |
|
|
|
const KernelGraphPtr &graph) { |
|
|
|
@@ -797,7 +862,7 @@ void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, Ker |
|
|
|
(void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, KernelActor *to_actor, |
|
|
|
void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, AbstractActor *to_actor, |
|
|
|
const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx, |
|
|
|
const KernelGraphPtr &graph) { |
|
|
|
@@ -831,13 +896,16 @@ void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, Ker |
|
|
|
} |
|
|
|
auto actor_pair = graph_output_to_actor_[front_output_with_index]; |
|
|
|
MS_EXCEPTION_IF_NULL(actor_pair.first); |
|
|
|
MS_EXCEPTION_IF_NULL(actor_pair.second.first); |
|
|
|
MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString() |
|
|
|
<< ", corresponding front node:" << front_output_node->fullname_with_scope() |
|
|
|
<< " with index:" << front_output_with_index.second |
|
|
|
<< ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second |
|
|
|
<< ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second; |
|
|
|
<< ", from actor:" << actor_pair.first->GetAID().Name() |
|
|
|
<< " node:" << actor_pair.second.first->fullname_with_scope() |
|
|
|
<< " with index:" << actor_pair.second.second << ", to actor:" << to_actor->GetAID().Name() |
|
|
|
<< " with index:" << to_kernel_with_input_idx.second; |
|
|
|
real_from_actor = actor_pair.first; |
|
|
|
real_from_kernel_with_output_idx = KernelWithIndex(nullptr, actor_pair.second); |
|
|
|
real_from_kernel_with_output_idx = actor_pair.second; |
|
|
|
kernel_type = actor_pair.first->type_; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -848,7 +916,7 @@ void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, Ker |
|
|
|
to_kernel_with_input_idx, graph); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor, |
|
|
|
void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor, |
|
|
|
const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(from_actor); |
|
|
|
@@ -884,7 +952,7 @@ void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor, |
|
|
|
void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor, |
|
|
|
const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx, |
|
|
|
const KernelGraphPtr &) { |
|
|
|
@@ -898,7 +966,7 @@ void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_act |
|
|
|
LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor, |
|
|
|
void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor, |
|
|
|
const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx, |
|
|
|
const KernelGraphPtr &) { |
|
|
|
@@ -919,7 +987,7 @@ void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor |
|
|
|
LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor, |
|
|
|
void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor, |
|
|
|
const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx, |
|
|
|
const KernelGraphPtr &) { |
|
|
|
@@ -954,7 +1022,7 @@ void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor |
|
|
|
LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor, |
|
|
|
void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor, |
|
|
|
const KernelWithIndex &from_kernel_with_output_idx, |
|
|
|
const KernelWithIndex &to_kernel_with_input_idx) { |
|
|
|
MS_EXCEPTION_IF_NULL(from_actor); |
|
|
|
@@ -1020,7 +1088,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, |
|
|
|
UpdateRefCount(copy_actor->output_.get()); |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node, |
|
|
|
void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, |
|
|
|
const KernelGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(to_actor); |
|
|
|
MS_EXCEPTION_IF_NULL(from_node); |
|
|
|
@@ -1098,7 +1166,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const An |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node) { |
|
|
|
void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(to_actor); |
|
|
|
MS_EXCEPTION_IF_NULL(skipped_node); |
|
|
|
auto to_aid = to_actor->GetAID(); |
|
|
|
@@ -1287,6 +1355,11 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun |
|
|
|
|
|
|
|
// Collect the actors which have no output. |
|
|
|
std::vector<MemoryAwareActor *> no_output_actors; |
|
|
|
for (auto &super_actor : actor_set->super_kernel_actors_) { |
|
|
|
if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) { |
|
|
|
(void)no_output_actors.emplace_back(super_actor.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
for (auto &kernel_actor : actor_set->kernel_actors_) { |
|
|
|
// The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor. |
|
|
|
if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) && |
|
|
|
@@ -1380,16 +1453,17 @@ void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor, |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position); |
|
|
|
auto position = from_actor->FetchNodePosition(output_with_index.first); |
|
|
|
// If the from actor has the multi nodes, then use the real output position. |
|
|
|
if (position != 0) { |
|
|
|
op_arrow->from_output_index_ = SizeToInt(position); |
|
|
|
} |
|
|
|
(void)from_actor->output_result_arrows_.emplace_back(op_arrow); |
|
|
|
(void)from_actor->output_nodes_.emplace_back(output_with_index.first); |
|
|
|
|
|
|
|
// Update the real compute node in the host data source actor. |
|
|
|
if (kernel_type == KernelTransformType::kHostDataSourceActor) { |
|
|
|
auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor); |
|
|
|
MS_EXCEPTION_IF_NULL(host_queue_ds_actor); |
|
|
|
UpdateRefCount(host_queue_ds_actor->data_nodes_[position], output_with_index.second, true); |
|
|
|
auto position = host_queue_ds_actor->FetchNodePosition(output_with_index.first); |
|
|
|
auto real_node = host_queue_ds_actor->FetchNode(position); |
|
|
|
from_actor->output_nodes_[from_actor->output_nodes_.size() - 1] = real_node; |
|
|
|
UpdateRefCount(real_node, output_with_index.second, true); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1474,6 +1548,15 @@ bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionSt |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// Check the super kernel actors. |
|
|
|
for (const auto &super_kernel_actor : actor_set->super_kernel_actors_) { |
|
|
|
MS_EXCEPTION_IF_NULL(super_kernel_actor); |
|
|
|
if (super_kernel_actor->output_data_arrows_.size() + super_kernel_actor->output_control_arrows_.size() == 0) { |
|
|
|
MS_LOG(ERROR) << super_kernel_actor->GetAID().Name() << " has no user."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Check the kernel actors. |
|
|
|
for (const auto &kernel_actor : actor_set->kernel_actors_) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_actor); |
|
|
|
@@ -1602,11 +1685,17 @@ void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, con |
|
|
|
const GraphCompilerInfo &graph_compiler_info, |
|
|
|
KernelTransformType *const kernel_type, |
|
|
|
std::string *const kernel_name) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_type); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_name); |
|
|
|
|
|
|
|
if (graph->is_sink() && ((node == nullptr) || node->isa<CNode>())) { |
|
|
|
*kernel_type = KernelTransformType::kSuperKernelActor; |
|
|
|
*kernel_name = graph->ToString() + "_SuperKernelActor"; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (IsDeviceQueueDSActor(node, graph_compiler_info.strategy_)) { |
|
|
|
*kernel_type = KernelTransformType::kDeviceDataSourceActor; |
|
|
|
*kernel_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id()); |
|
|
|
@@ -1682,9 +1771,14 @@ void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInf |
|
|
|
DumpKernelActor(kernel_actor.get(), ofs); |
|
|
|
} |
|
|
|
|
|
|
|
ofs << "\n\n[Super kernel actors:" << actor_set->super_kernel_actors_.size() << "]\n"; |
|
|
|
for (const auto &super_kernel_actor : actor_set->super_kernel_actors_) { |
|
|
|
DumpSuperKernelActor(super_kernel_actor.get(), ofs); |
|
|
|
} |
|
|
|
|
|
|
|
ofs << "\n\n[No input kernel actors:" << actor_set->no_input_kernel_actors_.size() << "]\n"; |
|
|
|
for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { |
|
|
|
DumpKernelActor(no_input_kernel_actor.get(), ofs); |
|
|
|
DumpNoInputKernelActor(no_input_kernel_actor.get(), ofs); |
|
|
|
} |
|
|
|
|
|
|
|
ofs << "\n\n[Copy actors:" << actor_set->copy_actors_.size() << "]\n"; |
|
|
|
@@ -1778,11 +1872,18 @@ void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (actor->output_result_arrows_.size() != actor->output_nodes_.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "The size of output result arrows is not equal to the output nodes."; |
|
|
|
} |
|
|
|
if (actor->output_result_arrows_.size() > 0) { |
|
|
|
ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n "; |
|
|
|
for (const auto &result_arrow : actor->output_result_arrows_) { |
|
|
|
for (size_t i = 0; i < actor->output_result_arrows_.size(); ++i) { |
|
|
|
auto result_arrow = actor->output_result_arrows_[i]; |
|
|
|
auto output_node = actor->output_nodes_[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(result_arrow); |
|
|
|
ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_ |
|
|
|
MS_EXCEPTION_IF_NULL(output_node); |
|
|
|
ofs << "\t\t\tfrom_output_node:" << output_node->fullname_with_scope() |
|
|
|
<< "tfrom_output_index:" << result_arrow->from_output_index_ |
|
|
|
<< "\tto_actor_name:" << result_arrow->to_op_id_.Name() |
|
|
|
<< "\toutput_node_position:" << result_arrow->to_input_index_ << "\n"; |
|
|
|
} |
|
|
|
@@ -1882,6 +1983,34 @@ void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &of |
|
|
|
ofs << "\n"; |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::DumpSuperKernelActor(const SuperKernelActor *actor, std::ofstream &ofs) const { |
|
|
|
MS_EXCEPTION_IF_NULL(actor); |
|
|
|
ofs << "\tactor_name:" << actor->GetAID().Name() << "\n"; |
|
|
|
|
|
|
|
const auto &graph = actor->graph_; |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
|
|
|
|
ofs << "\t\tgraph id:" << graph->graph_id() << "\tgraphl_name:" << graph->ToString() |
|
|
|
<< "\tis_sink:" << graph->is_sink() << "\tinputs_num:" << (graph->input_nodes()).size() |
|
|
|
<< "\tkernels_num:" << (graph->execution_order()).size() << "\n"; |
|
|
|
|
|
|
|
DumpAbstractActor(actor, ofs); |
|
|
|
ofs << "\n"; |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::DumpNoInputKernelActor(const AbstractActor *actor, std::ofstream &ofs) const { |
|
|
|
MS_EXCEPTION_IF_NULL(actor); |
|
|
|
if (actor->type_ == KernelTransformType::kKernelActor) { |
|
|
|
auto kernel_actor = dynamic_cast<const KernelActor *>(actor); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_actor); |
|
|
|
DumpKernelActor(kernel_actor, ofs); |
|
|
|
} else if (actor->type_ == KernelTransformType::kSuperKernelActor) { |
|
|
|
auto super_kernel_actor = dynamic_cast<const SuperKernelActor *>(actor); |
|
|
|
MS_EXCEPTION_IF_NULL(super_kernel_actor); |
|
|
|
DumpSuperKernelActor(super_kernel_actor, ofs); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const { |
|
|
|
MS_EXCEPTION_IF_NULL(actor); |
|
|
|
ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ |
|
|
|
@@ -1906,7 +2035,7 @@ void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) c |
|
|
|
void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const { |
|
|
|
for (const auto &graph : graph_compiler_info.graphs_) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
ofs << "\tgraph id:" << graph->graph_id() << "\n"; |
|
|
|
ofs << "\tgraph id:" << graph->graph_id() << "\tis_sink:" << graph->is_sink() << "\n"; |
|
|
|
|
|
|
|
for (auto &value_node : graph->graph_value_nodes()) { |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
|