From c937a22bda7c3fc18f13eb09603cfe6f9dbfc59b Mon Sep 17 00:00:00 2001
From: limingqi107
Date: Sat, 17 Apr 2021 10:17:33 +0800
Subject: [PATCH] add the actor link by auto monad

---
 mindspore/ccsrc/pipeline/jit/pipeline.cc      | 13 +++
 .../runtime/framework/actor/kernel_actor.cc   |  1 +
 .../runtime/framework/graph_scheduler.cc      | 86 +++++++++++++++++--
 .../ccsrc/runtime/framework/graph_scheduler.h |  3 +
 mindspore/ccsrc/utils/utils.h                 |  3 +
 5 files changed, 98 insertions(+), 8 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 357f7e16ae..d5fed8ac56 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -1051,6 +1051,19 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
   auto backend = compile::CreateBackend();
   MS_EXCEPTION_IF_NULL(backend);
+  // Compile and run the dataset graph in the MindRT backend.
+  if (compile::IsMindRTUsed()) {
+    ConfigManager::GetInstance().set_iter_num(size);
+    const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
+    MS_EXCEPTION_IF_NULL(mindrt_backend);
+    auto graph_id = mindrt_backend->CompileGraph({app_init});
+    VectorRef args;
+    if (need_run) {
+      (void)mindrt_backend->RunGraph(graph_id, args);
+    }
+    return true;
+  }
+
   auto convert_fn = backend->convert_fn();
   MS_EXCEPTION_IF_NULL(convert_fn);
   // Convert CNodeList to LinConvertResult.
diff --git a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
index 28e64e89c2..8c82adf6f5 100644
--- a/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
+++ b/mindspore/ccsrc/runtime/framework/actor/kernel_actor.cc
@@ -76,6 +76,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) {
   }
   SendOutput(context);
   FreeMemory(context);
+  EraseInput(context);
 }
 
 bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const {
diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
index 9b09816c37..5ca03291ec 100644
--- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
+++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc
@@ -64,8 +64,7 @@ bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
   return false;
 }
 
-KernelActor *FindKernelActor(const std::unordered_map<std::string, KernelActorPtr> &kernel_actors_map,
-                             const std::string &name) {
+KernelActor *FindKernelActor(const KernelMapActor &kernel_actors_map, const std::string &name) {
   auto iter = kernel_actors_map.find(name);
   if (iter != kernel_actors_map.end()) {
     return iter->second.get();
@@ -185,6 +184,16 @@ void PrepareDataForWeightNode(const AnfNodePtr &node, const TensorPtr &tensor, c
   MS_EXCEPTION_IF_NULL(tensor);
   const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0);
   MS_EXCEPTION_IF_NULL(device_tensor);
+  const auto &host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
+  // If the host tensor already has a device address, that address holds the latest data, so bind it to the node.
+  if (host_tensor_address != nullptr) {
+    if (host_tensor_address != device_tensor) {
+      AnfAlgo::SetOutputAddr(host_tensor_address, 0, node.get());
+      DeviceTensorStore::GetInstance().Insert(node.get(), host_tensor_address);
+    }
+    return;
+  }
+
   // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
   if (device_tensor->GetPtr() != nullptr) {
     return;
@@ -382,8 +391,9 @@ void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const std::vector
-  MS_EXCEPTION_IF_NULL(host_tensor_queue);
-  host_tensor_queue->PushData(host_tensors);
+  if (host_tensor_queue != nullptr) {
+    host_tensor_queue->PushData(host_tensors);
+  }
 
   // 3.Prepare the output tensor of graph.
   for (const auto &output_node : graph->outputs()) {
@@ -472,7 +482,7 @@ ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceConte
 void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
   MS_EXCEPTION_IF_NULL(actor_set);
   MS_EXCEPTION_IF_NULL(graph);
-  std::unordered_map<std::string, KernelActorPtr> kernel_actors_temp_map;
+  KernelMapActor kernel_actors_temp_map;
   for (auto &actor : actor_set->kernel_actors_) {
     MS_EXCEPTION_IF_NULL(actor);
     kernel_actors_temp_map.emplace(actor->GetAID().Name(), actor);
@@ -488,8 +498,15 @@ void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, Grap
     // Link the control arrows of kernel actor.
     LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy);
 
-    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-      KernelWithIndex from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, i, true);
+    for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
+      auto input_node = AnfAlgo::GetInputNode(kernel, i);
+      // Link the control arrows of the kernel actor by auto monad, because the inputs may include monad nodes.
+      LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
+      if (HasAbstractMonad(input_node)) {
+        continue;  // No data arrow for monad input.
+      }
+
+      KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
       KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
       auto from_kernel = from_kernel_with_output_idx.first;
@@ -583,6 +600,8 @@ std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const Kernel
     MS_EXCEPTION_IF_NULL(kernel_actor);
     if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
       no_input_kernel_actors.emplace_back(kernel_actor);
+      // The no-input kernel actor will be triggered by the loop count actor, so its input_controls_num_ needs to be set.
+      kernel_actor->input_controls_num_ = 1;
     }
   }
   return no_input_kernel_actors;
@@ -689,6 +708,58 @@ void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, Loo
   }
 }
 
+void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
+                                                 const KernelMapActor &kernel_actors_map) {
+  MS_EXCEPTION_IF_NULL(to_actor);
+  MS_EXCEPTION_IF_NULL(from_node);
+  if (!from_node->isa<CNode>()) {
+    return;
+  }
+  // Find the real input node, which may be a monad node or a make tuple node.
+  const std::vector<PrimitivePtr> &return_types = {prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
+  const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, true, return_types);
+  MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
+  if (!input_kernel_with_output_idx.first->isa<CNode>()) {
+    return;
+  }
+  const auto &input_cnode = input_kernel_with_output_idx.first->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(input_cnode);
+
+  // Get the real depend input of the monad node, which is the node to link the control arrow from.
+  AnfNodePtr real_depend_input = nullptr;
+  if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimUpdateState)) {
+    real_depend_input = input_cnode->input(kUpdateStateRealInput);
+  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
+    real_depend_input = input_cnode->input(kLoadStateInput);
+  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
+    // The make tuple node needs to be expanded.
+    for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
+      LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), kernel_actors_map);
+    }
+    return;
+  } else {
+    return;
+  }
+
+  MS_EXCEPTION_IF_NULL(real_depend_input);
+  if (!real_depend_input->isa<CNode>()) {
+    return;
+  }
+  // The monad node and the make tuple node need to be handled recursively.
+  if (AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimUpdateState) ||
+      AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimLoad) ||
+      AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimMakeTuple)) {
+    LinkControlArrowByAutoMonad(to_actor, real_depend_input, kernel_actors_map);
+    return;
+  }
+
+  // Link the control arrow between the kernel actors.
+  auto from_actor = FindKernelActor(kernel_actors_map, real_depend_input->fullname_with_scope());
+  MS_EXCEPTION_IF_NULL(from_actor);
+  from_actor->output_op_controls_.emplace_back(to_actor->GetAID());
+  to_actor->input_controls_num_++;
+}
+
 void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(loop_count_actor);
@@ -784,7 +855,6 @@ HostTensorQueue *GraphScheduler::FetchHostQueue(const KernelGraphPtr &graph) con
   if (iter != graph_to_host_queue_.end()) {
     return iter->second.get();
   } else {
-    MS_LOG(ERROR) << "Can't find the host tensor queue map of graph: " << graph->ToString();
     return nullptr;
   }
 }
diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h
index f43d67b12f..4ae1223be2 100644
--- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h
+++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h
@@ -34,6 +34,7 @@ namespace mindspore {
 namespace runtime {
 using mindspore::device::DeviceContext;
 using mindspore::session::KernelWithIndex;
+using KernelMapActor = std::unordered_map<std::string, KernelActorPtr>;
 
 enum class GraphExecutionStrategy {
   kPipeline,  // The actor running is triggered only by data.
@@ -118,6 +119,8 @@ class GraphScheduler {
   void LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor, const KernelGraphPtr &graph,
                                       GraphExecutionStrategy strategy);
   void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph);
+  void LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
+                                   const KernelMapActor &kernel_actors_map);
 
   // Check whether the actor set is valid.
   bool CheckActorValid(const ActorSet *actor_set) const;
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 8c5b6a79cd..5ff9341482 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -461,6 +461,9 @@ constexpr auto kDependInputSize = 3;
 // index define of UpdateState
 constexpr auto kUpdateStateStateInput = 1;
 constexpr auto kUpdateStateRealInput = 2;
+// index define of Load
+constexpr auto kLoadRealInput = 1;
+constexpr auto kLoadStateInput = 2;
 // format
 constexpr auto kOpFormat_DEFAULT = "DefaultFormat";
 constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0";
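
Note (illustration, not part of the patch): the sketch below is a minimal, self-contained C++ model of the traversal that LinkControlArrowByAutoMonad performs above. It expands MakeTuple inputs, follows the depend/state input of UpdateState and Load (kUpdateStateRealInput / kLoadStateInput in the patch), recurses through nested monad nodes, and finally adds a control arrow from the real kernel's actor to the consuming actor. The Node, Actor, and Kind types, the LinkControlByAutoMonad function, and the Assign/UpdateState/Load/MatMul toy graph are hypothetical stand-ins rather than MindSpore APIs.

// Illustrative sketch only: models the auto-monad control-arrow walk, not MindSpore code.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

enum class Kind { kKernel, kUpdateState, kLoad, kMakeTuple, kParameter };

struct Node {
  Kind kind;
  std::string name;
  std::vector<std::shared_ptr<Node>> inputs;  // payload inputs only (no primitive slot)
};

struct Actor {
  std::string name;
  std::vector<std::string> output_op_controls;  // names of downstream actors
  size_t input_controls_num;
};

// Walk through monad/make-tuple wrappers until a real kernel is reached, then
// add a control arrow: real kernel's actor -> to_actor.
void LinkControlByAutoMonad(Actor *to_actor, const std::shared_ptr<Node> &from,
                            const std::unordered_map<std::string, Actor *> &actors) {
  if (to_actor == nullptr || from == nullptr) {
    return;
  }
  switch (from->kind) {
    case Kind::kMakeTuple:
      // Expand the tuple: every element may carry its own dependency.
      for (const auto &input : from->inputs) {
        LinkControlByAutoMonad(to_actor, input, actors);
      }
      return;
    case Kind::kUpdateState:
    case Kind::kLoad:
      // In the patch the depend input sits at a fixed index (kUpdateStateRealInput /
      // kLoadStateInput); in this toy model it is the last input. Recurse because it
      // may itself be another monad node.
      if (!from->inputs.empty()) {
        LinkControlByAutoMonad(to_actor, from->inputs.back(), actors);
      }
      return;
    case Kind::kKernel: {
      auto iter = actors.find(from->name);
      if (iter == actors.end()) {
        return;
      }
      iter->second->output_op_controls.push_back(to_actor->name);
      ++to_actor->input_controls_num;
      return;
    }
    default:
      return;  // parameters/values carry no control dependency
  }
}

int main() {
  // Toy graph: Assign has a side effect; UpdateState(U, Assign) records it;
  // Load(weight, UpdateState) reads the updated weight; MatMul consumes the Load.
  auto assign = std::make_shared<Node>(Node{Kind::kKernel, "Assign", {}});
  auto u = std::make_shared<Node>(Node{Kind::kParameter, "U", {}});
  auto update = std::make_shared<Node>(Node{Kind::kUpdateState, "UpdateState", {u, assign}});
  auto weight = std::make_shared<Node>(Node{Kind::kParameter, "weight", {}});
  auto load = std::make_shared<Node>(Node{Kind::kLoad, "Load", {weight, update}});

  Actor assign_actor{"Assign", {}, 0};
  Actor matmul_actor{"MatMul", {}, 0};
  std::unordered_map<std::string, Actor *> actors = {{"Assign", &assign_actor}};

  // Linking MatMul's Load input yields the control arrow Assign -> MatMul.
  LinkControlByAutoMonad(&matmul_actor, load, actors);

  std::cout << "MatMul input_controls_num = " << matmul_actor.input_controls_num << "\n";
  std::cout << "Assign -> " << assign_actor.output_op_controls.front() << "\n";
  return 0;
}

Compiling and running this prints "MatMul input_controls_num = 1" and "Assign -> MatMul", which is the kind of extra control dependency the patch introduces so that side-effect operators execute before the kernels that read their results.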