| @@ -397,7 +397,7 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedList(con | |||||
| MS_EXCEPTION_IF_NULL(manager); | MS_EXCEPTION_IF_NULL(manager); | ||||
| auto iter = manager->node_users().find(node); | auto iter = manager->node_users().find(node); | ||||
| if (iter == manager->node_users().end()) { | if (iter == manager->node_users().end()) { | ||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||||
| return output_node_list; | |||||
| } | } | ||||
| auto output_info_list = iter->second; | auto output_info_list = iter->second; | ||||
| for (const auto &output_info : output_info_list) { | for (const auto &output_info : output_info_list) { | ||||
| @@ -469,7 +469,8 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||||
| auto out_node = output.first; | auto out_node = output.first; | ||||
| auto name = AnfAlgo::GetCNodeName(out_node); | auto name = AnfAlgo::GetCNodeName(out_node); | ||||
| if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() || | if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() || | ||||
| name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) { | |||||
| name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name() || | |||||
| name == prim::kPrimReturn->name()) { | |||||
| auto result = IsNotRealUsedByOthers(graph, out_node); | auto result = IsNotRealUsedByOthers(graph, out_node); | ||||
| if (!result) { | if (!result) { | ||||
| return result; | return result; | ||||
| @@ -757,6 +757,13 @@ AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) { | |||||
| return front_backend_anf_map_[front_anf]; | return front_backend_anf_map_[front_anf]; | ||||
| } | } | ||||
| AnfNodePtr KernelGraph::GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf) { | |||||
| if (backend_front_anf_map_.find(backend_anf) == backend_front_anf_map_.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return backend_front_anf_map_[backend_anf]; | |||||
| } | |||||
| bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) { | bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) { | ||||
| return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end(); | return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end(); | ||||
| } | } | ||||
| @@ -122,6 +122,8 @@ class KernelGraph : public FuncGraph { | |||||
| void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf); | void FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf); | ||||
| // get backend anf by front anf | // get backend anf by front anf | ||||
| AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf); | AnfNodePtr GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf); | ||||
| // get front anf by backend anf | |||||
| AnfNodePtr GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf); | |||||
| // check backend node whether exist in map | // check backend node whether exist in map | ||||
| bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf); | bool BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf); | ||||
| // get value node by tensor | // get value node by tensor | ||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "runtime/framework/actor/actor_common.h" | |||||
| #include <unistd.h> | |||||
| #ifdef __WIN32__ | |||||
| #include <windows.h> | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace runtime { | |||||
// Get the max available thread number of system.
// On Windows this queries the processor count via GetSystemInfo; elsewhere it
// uses sysconf(_SC_NPROCESSORS_ONLN). Always returns at least 1, even when the
// underlying query fails (sysconf returns -1 on error).
int64_t GetMaxThreadNum() {
#ifdef __WIN32__
  SYSTEM_INFO sys_info;
  GetSystemInfo(&sys_info);
  int64_t max_thread_num = static_cast<int64_t>(sys_info.dwNumberOfProcessors);
#else
  int64_t max_thread_num = sysconf(_SC_NPROCESSORS_ONLN);
#endif
  // Robustness: never report a non-positive thread count to callers that size
  // a thread pool with it.
  return (max_thread_num < 1) ? 1 : max_thread_num;
}
| } // namespace runtime | |||||
| } // namespace mindspore | |||||
| @@ -40,6 +40,9 @@ constexpr int kFailure = 1; | |||||
| return; \ | return; \ | ||||
| } | } | ||||
| // Get the max available thread number of system. | |||||
| int64_t GetMaxThreadNum(); | |||||
| } // namespace runtime | } // namespace runtime | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -24,6 +24,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace runtime { | namespace runtime { | ||||
| void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) { | void DataSourceActor::FetchData(OpContext<DeviceTensor> *context) { | ||||
| MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") fetches data."; | |||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| if (buffers_.size() == buffer_capacity_) { | if (buffers_.size() == buffer_capacity_) { | ||||
| // Send output to trigger computing and free memory. | // Send output to trigger computing and free memory. | ||||
| @@ -54,6 +55,7 @@ void DataSourceActor::FreeMemory(OpContext<DeviceTensor> *context) { | |||||
| } | } | ||||
| void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) { | void DataSourceActor::SendOutput(OpContext<DeviceTensor> *context) { | ||||
| MS_LOG(INFO) << "Data source actor(" << GetAID().Name() << ") sends output data."; | |||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| if (buffers_.size() == 0) { | if (buffers_.size() == 0) { | ||||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty."); | SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "The data queue is empty."); | ||||
| @@ -57,6 +57,8 @@ class DataSourceActor : public MemoryInterfaceActor { | |||||
| void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override{}; | void OnMemoryAllocFinish(OpContext<DeviceTensor> *context) override{}; | ||||
| protected: | protected: | ||||
| friend class GraphScheduler; | |||||
| // Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching. | // Construct the device tensors and fill to device tensor buffer from the member nodes during the data fetching. | ||||
| virtual void FillDataBuffer() = 0; | virtual void FillDataBuffer() = 0; | ||||
| @@ -78,7 +78,7 @@ void KernelActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *context) { | |||||
| FreeMemory(context); | FreeMemory(context); | ||||
| } | } | ||||
| bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) { | |||||
| bool KernelActor::CheckLaunchCondition(OpContext<DeviceTensor> *context) const { | |||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| if (input_datas_num_ != 0) { | if (input_datas_num_ != 0) { | ||||
| auto data_iter = input_op_datas_.find(context->sequential_num_); | auto data_iter = input_op_datas_.find(context->sequential_num_); | ||||
| @@ -136,16 +136,14 @@ void KernelActor::FetchWorkspaceDeviceTensor() { | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | MS_EXCEPTION_IF_NULL(kernel_mod); | ||||
| auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); | auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); | ||||
| for (size_t i = 0; i < workspace_sizes.size(); ++i) { | for (size_t i = 0; i < workspace_sizes.size(); ++i) { | ||||
| if (workspace_sizes[i] != 0) { | |||||
| auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel_, i); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | |||||
| workspace_device_tensors_.emplace_back(device_address.get()); | |||||
| } | |||||
| auto device_address = AnfAlgo::GetMutableWorkspaceAddr(kernel_, i); | |||||
| MS_EXCEPTION_IF_NULL(device_address); | |||||
| workspace_device_tensors_.emplace_back(device_address.get()); | |||||
| } | } | ||||
| } | } | ||||
| void KernelActor::FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs, | void KernelActor::FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs, | ||||
| std::vector<AddressPtr> *kernel_workspaces) { | |||||
| std::vector<AddressPtr> *kernel_workspaces) const { | |||||
| MS_EXCEPTION_IF_NULL(kernel_inputs); | MS_EXCEPTION_IF_NULL(kernel_inputs); | ||||
| MS_EXCEPTION_IF_NULL(kernel_outputs); | MS_EXCEPTION_IF_NULL(kernel_outputs); | ||||
| MS_EXCEPTION_IF_NULL(kernel_workspaces); | MS_EXCEPTION_IF_NULL(kernel_workspaces); | ||||
| @@ -165,7 +163,7 @@ void KernelActor::FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::v | |||||
| } | } | ||||
| } | } | ||||
| void KernelActor::SendOutput(OpContext<DeviceTensor> *context) { | |||||
| void KernelActor::SendOutput(OpContext<DeviceTensor> *context) const { | |||||
| MS_EXCEPTION_IF_NULL(context); | MS_EXCEPTION_IF_NULL(context); | ||||
| // Send output data. | // Send output data. | ||||
| for (auto &op_arrow : output_op_arrows_) { | for (auto &op_arrow : output_op_arrows_) { | ||||
| @@ -186,5 +184,16 @@ void KernelActor::SendOutput(OpContext<DeviceTensor> *context) { | |||||
| } | } | ||||
| } | } | ||||
| void KernelActor::EraseInput(OpContext<DeviceTensor> *context) { | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| if (input_datas_num_ != 0) { | |||||
| (void)input_op_datas_.erase(context->sequential_num_); | |||||
| } | |||||
| if (input_controls_num_ != 0) { | |||||
| (void)input_op_controls_.erase(context->sequential_num_); | |||||
| } | |||||
| } | |||||
| } // namespace runtime | } // namespace runtime | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -64,12 +64,14 @@ class KernelActor : public MemoryInterfaceActor { | |||||
| friend class GraphScheduler; | friend class GraphScheduler; | ||||
| // Check whether satisfy the condition for launch. | // Check whether satisfy the condition for launch. | ||||
| bool CheckLaunchCondition(OpContext<DeviceTensor> *context); | |||||
| bool CheckLaunchCondition(OpContext<DeviceTensor> *context) const; | |||||
| // Fetch the args of kernel launch. | // Fetch the args of kernel launch. | ||||
| void FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs, | void FetchLaunchArgs(std::vector<AddressPtr> *kernel_inputs, std::vector<AddressPtr> *kernel_outputs, | ||||
| std::vector<AddressPtr> *kernel_workspaces); | |||||
| std::vector<AddressPtr> *kernel_workspaces) const; | |||||
| // Send output data and output controls when finish kernel launch. | // Send output data and output controls when finish kernel launch. | ||||
| void SendOutput(OpContext<DeviceTensor> *context); | |||||
| void SendOutput(OpContext<DeviceTensor> *context) const; | |||||
| // Erase input data and input controls when finish kernel launch. | |||||
| void EraseInput(OpContext<DeviceTensor> *context); | |||||
| // Fetch the device tensor for launch. | // Fetch the device tensor for launch. | ||||
| void FetchInputDeviceTensor(OpContext<DeviceTensor> *context); | void FetchInputDeviceTensor(OpContext<DeviceTensor> *context); | ||||
| @@ -28,6 +28,9 @@ void LoopCountActor::RunOpControl(AID *input_control, OpContext<DeviceTensor> *c | |||||
| input_op_controls_[sequential_num].emplace_back(input_control); | input_op_controls_[sequential_num].emplace_back(input_control); | ||||
| if (input_op_controls_[sequential_num].size() == input_controls_num_) { | if (input_op_controls_[sequential_num].size() == input_controls_num_) { | ||||
| current_count_++; | current_count_++; | ||||
| (void)input_op_controls_.erase(sequential_num); | |||||
| MS_LOG(INFO) << "Loop count actor(" << GetAID().Name() << ") runs op control, loop count: " << loop_count_ | |||||
| << ", current count: " << current_count_; | |||||
| if (current_count_ == loop_count_) { | if (current_count_ == loop_count_) { | ||||
| current_count_ = 0; | current_count_ = 0; | ||||
| SET_OPCONTEXT_SUCCESS_RET((*context)); | SET_OPCONTEXT_SUCCESS_RET((*context)); | ||||
| @@ -29,7 +29,7 @@ void MemoryManagerActor::AllocateMemory(std::vector<DeviceTensor *> alloc_list, | |||||
| for (auto &device_tensor : alloc_list) { | for (auto &device_tensor : alloc_list) { | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| if (device_tensor->GetPtr() != nullptr) { | |||||
| if ((device_tensor->GetPtr() != nullptr) || (device_tensor->GetSize() == 0)) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| // Allocate memory through the device context. | // Allocate memory through the device context. | ||||
| @@ -53,7 +53,9 @@ void MemoryManagerActor::FreeMemory(std::vector<DeviceTensor *> free_list, const | |||||
| device_tensor->DecreaseRefCountUsed(); | device_tensor->DecreaseRefCountUsed(); | ||||
| if (device_tensor->ref_count_dynamic_used() == 0) { | if (device_tensor->ref_count_dynamic_used() == 0) { | ||||
| // Free memory through the device context. | // Free memory through the device context. | ||||
| device_context->FreeMemory(device_tensor); | |||||
| if (device_tensor->GetPtr() != nullptr) { | |||||
| device_context->FreeMemory(device_tensor); | |||||
| } | |||||
| device_tensor->ResetRefCountUsed(); | device_tensor->ResetRefCountUsed(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -209,6 +209,7 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt | |||||
| } | } | ||||
| GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | ||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(device_context_); | MS_EXCEPTION_IF_NULL(device_context_); | ||||
| // Optimization pass which is irrelevant to device type or format. | // Optimization pass which is irrelevant to device type or format. | ||||
| device_context_->OptimizeGraphWithoutDeviceInfo(graph); | device_context_->OptimizeGraphWithoutDeviceInfo(graph); | ||||
| @@ -224,8 +225,11 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { | |||||
| // Create device address for all anf nodes of graph. | // Create device address for all anf nodes of graph. | ||||
| CreateDeviceAddress(graph); | CreateDeviceAddress(graph); | ||||
| // Transform graph to actor DAG, contains build and link. | // Transform graph to actor DAG, contains build and link. | ||||
| GraphScheduler::GetInstance().Transform(graph, device_context_); | |||||
| const auto &actor_set = GraphScheduler::GetInstance().Transform(graph, device_context_); | |||||
| GraphScheduler::GetInstance().Schedule(actor_set); | |||||
| return graph->graph_id(); | return graph->graph_id(); | ||||
| } | } | ||||
| @@ -243,6 +243,8 @@ BaseRef CreateOutputTensor(const session::KernelWithIndex &node_output_pair, con | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index); | const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | MS_EXCEPTION_IF_NULL(device_tensor); | ||||
| tensor->set_device_address(device_tensor); | tensor->set_device_address(device_tensor); | ||||
| device_tensor->set_ref_count(SIZE_MAX); | |||||
| device_tensor->ResetRefCountUsed(); | |||||
| return tensor; | return tensor; | ||||
| } | } | ||||
| } | } | ||||
| @@ -280,24 +282,40 @@ void GraphScheduler::Initialize() { | |||||
| } | } | ||||
| init_ = true; | init_ = true; | ||||
| auto actorMgr = ActorMgr::GetActorMgrRef(); | |||||
| MS_EXCEPTION_IF_NULL(actorMgr); | |||||
| // Create the thread pool of actor runtime. | |||||
| auto max_thread_num = GetMaxThreadNum(); | |||||
| MS_LOG(INFO) << "Max available thread number: " << max_thread_num; | |||||
| actorMgr->Initialize(max_thread_num); | |||||
| // Create memory manager actor. | // Create memory manager actor. | ||||
| auto memory_manager_actor = std::make_shared<MemoryManagerActor>(); | auto memory_manager_actor = std::make_shared<MemoryManagerActor>(); | ||||
| MS_EXCEPTION_IF_NULL(memory_manager_actor); | MS_EXCEPTION_IF_NULL(memory_manager_actor); | ||||
| memory_manager_aid_ = memory_manager_actor->GetAID(); | memory_manager_aid_ = memory_manager_actor->GetAID(); | ||||
| // Schedule memory manager actor, bind single thread to response to memory alloc and free quickly. | // Schedule memory manager actor, bind single thread to response to memory alloc and free quickly. | ||||
| auto base_actor = static_cast<ActorReference>(memory_manager_actor); | auto base_actor = static_cast<ActorReference>(memory_manager_actor); | ||||
| auto actorMgr = ActorMgr::GetActorMgrRef(); | |||||
| MS_EXCEPTION_IF_NULL(actorMgr); | |||||
| (void)actorMgr->Spawn(base_actor, false); | (void)actorMgr->Spawn(base_actor, false); | ||||
| } | } | ||||
| ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, | ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context, | ||||
| const std::vector<tensor::TensorPtr> *input_tensors, | const std::vector<tensor::TensorPtr> *input_tensors, | ||||
| GraphExecutionStrategy strategy) { | GraphExecutionStrategy strategy) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor begin."; | |||||
| Initialize(); | |||||
| PersistDeviceTensor(graph); | PersistDeviceTensor(graph); | ||||
| auto actor_set = Build(graph, device_context); | auto actor_set = Build(graph, device_context); | ||||
| graph_to_actors_.emplace(graph, actor_set); | graph_to_actors_.emplace(graph, actor_set); | ||||
| Link(actor_set.get(), graph, strategy); | Link(actor_set.get(), graph, strategy); | ||||
| if (!CheckActorValid(actor_set.get())) { | |||||
| MS_LOG(EXCEPTION) << "The actor set of " << graph->ToString() << " is invalid."; | |||||
| } | |||||
| MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor end."; | |||||
| return actor_set.get(); | return actor_set.get(); | ||||
| } | } | ||||
| @@ -327,16 +345,23 @@ void GraphScheduler::Schedule(const ActorSet *actor_set) { | |||||
| } | } | ||||
| } | } | ||||
| void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const DeviceContext *device_context, | |||||
| const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs) { | |||||
| void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors, | |||||
| VectorRef *const &outputs) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| MS_EXCEPTION_IF_NULL(input_tensors); | MS_EXCEPTION_IF_NULL(input_tensors); | ||||
| MS_EXCEPTION_IF_NULL(outputs); | MS_EXCEPTION_IF_NULL(outputs); | ||||
| // Get the device context for the first kernel actor. | |||||
| const auto &actor_set = Fetch(graph); | |||||
| MS_EXCEPTION_IF_NULL(actor_set); | |||||
| const auto &first_kernel_actor = actor_set->kernel_actors_[0]; | |||||
| MS_EXCEPTION_IF_NULL(first_kernel_actor); | |||||
| const auto &device_context = first_kernel_actor->device_context_; | |||||
| // 1.Prepare the data of device tensor store(value nodes of graph). | // 1.Prepare the data of device tensor store(value nodes of graph). | ||||
| for (const auto &value_node : graph->graph_value_nodes()) { | for (const auto &value_node : graph->graph_value_nodes()) { | ||||
| PrepareDataForValueNode(value_node, device_context); | |||||
| if (AnfAlgo::OutputAddrExist(value_node, 0)) { | |||||
| PrepareDataForValueNode(value_node, device_context); | |||||
| } | |||||
| } | } | ||||
| // 1.Prepare the data of device tensor store(weights of graph), and fill the host tensors for non weighted parameters. | // 1.Prepare the data of device tensor store(weights of graph), and fill the host tensors for non weighted parameters. | ||||
| @@ -372,10 +397,10 @@ bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strat | |||||
| MS_EXCEPTION_IF_NULL(actor_set); | MS_EXCEPTION_IF_NULL(actor_set); | ||||
| // Construct OpContext. | // Construct OpContext. | ||||
| OpContext<DeviceTensor> op_context; | OpContext<DeviceTensor> op_context; | ||||
| auto sequential_num = uuids::RandomBasedGenerator::GenerateRandomUuid(); | |||||
| uuids::uuid sequential_num; | |||||
| std::vector<Promise<int>> result(1); | |||||
| op_context.sequential_num_ = &sequential_num; | op_context.sequential_num_ = &sequential_num; | ||||
| Promise<int> result; | |||||
| op_context.results_->push_back(result); | |||||
| op_context.results_ = &result; | |||||
| // Trigger no input kernel actor running. | // Trigger no input kernel actor running. | ||||
| for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { | for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { | ||||
| @@ -398,11 +423,22 @@ bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strat | |||||
| } | } | ||||
| // Get the run result. | // Get the run result. | ||||
| auto result_future = result.GetFuture(); | |||||
| auto result_future = result[0].GetFuture(); | |||||
| result_future.Wait(); | result_future.Wait(); | ||||
| if (!result_future.IsOK()) { | if (!result_future.IsOK()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| // Sync device stream. | |||||
| const auto &first_kernel_actor = actor_set->kernel_actors_[0]; | |||||
| MS_EXCEPTION_IF_NULL(first_kernel_actor); | |||||
| const auto &device_context = first_kernel_actor->device_context_; | |||||
| MS_EXCEPTION_IF_NULL(device_context); | |||||
| if (!device_context->SyncStream()) { | |||||
| MS_LOG(ERROR) << "Sync stream failed."; | |||||
| return false; | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -453,7 +489,7 @@ void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, Grap | |||||
| LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy); | LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy); | ||||
| for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { | ||||
| KernelWithIndex from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, i); | |||||
| KernelWithIndex from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, i, true); | |||||
| KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i); | KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i); | ||||
| auto from_kernel = from_kernel_with_output_idx.first; | auto from_kernel = from_kernel_with_output_idx.first; | ||||
| @@ -557,6 +593,7 @@ LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const KernelGraphPtr &grap | |||||
| auto loop_count = ConfigManager::GetInstance().iter_num(); | auto loop_count = ConfigManager::GetInstance().iter_num(); | ||||
| auto actor_name = graph->ToString() + "_" + "LoopCountActor"; | auto actor_name = graph->ToString() + "_" + "LoopCountActor"; | ||||
| auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count); | auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count); | ||||
| MS_LOG(INFO) << "Create loop count actor: " << actor_name; | |||||
| MS_EXCEPTION_IF_NULL(loop_count_actor); | MS_EXCEPTION_IF_NULL(loop_count_actor); | ||||
| return loop_count_actor; | return loop_count_actor; | ||||
| } | } | ||||
| @@ -640,7 +677,12 @@ void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, Loo | |||||
| from_actor->input_controls_num_++; | from_actor->input_controls_num_++; | ||||
| } | } | ||||
| // The manager of graph member is weak ptr, so need created and used in the function IsNotRealUsedByOthers. | |||||
| const auto &manager = Manage(graph, true); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| if (opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) { | if (opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) { | ||||
| MS_EXCEPTION_IF_NULL(from_actor->kernel_); | |||||
| MS_LOG(INFO) << from_actor->kernel_->fullname_with_scope() << " is not real used by other nodes."; | |||||
| auto to_aid = to_actor->GetAID(); | auto to_aid = to_actor->GetAID(); | ||||
| from_actor->output_op_controls_.emplace_back(to_aid); | from_actor->output_op_controls_.emplace_back(to_aid); | ||||
| to_actor->input_controls_num_++; | to_actor->input_controls_num_++; | ||||
| @@ -667,11 +709,57 @@ void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_coun | |||||
| } | } | ||||
| } | } | ||||
| bool GraphScheduler::CheckActorValid(const ActorSet *actor_set) const { | |||||
| MS_EXCEPTION_IF_NULL(actor_set); | |||||
| // Check the data source actors. | |||||
| for (const auto &data_source_actor : actor_set->data_source_actors_) { | |||||
| MS_EXCEPTION_IF_NULL(data_source_actor); | |||||
| if (data_source_actor->output_op_arrows_.size() == 0) { | |||||
| MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // Check the kernel actors. | |||||
| for (const auto &kernel_actor : actor_set->kernel_actors_) { | |||||
| MS_EXCEPTION_IF_NULL(kernel_actor); | |||||
| if (kernel_actor->output_op_arrows_.size() + kernel_actor->output_op_controls_.size() == 0) { | |||||
| MS_LOG(ERROR) << kernel_actor->GetAID().Name() << " has no user."; | |||||
| return false; | |||||
| } | |||||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_); | |||||
| auto input_data_num = kernel_actor->input_datas_num_; | |||||
| auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size(); | |||||
| if (input_data_num + device_tensor_store_num != input_num) { | |||||
| MS_LOG(ERROR) << "The input building of " << kernel_actor->GetAID().Name() | |||||
| << " is wrong, input data num: " << input_data_num | |||||
| << ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // Check the loop count actor. | |||||
| const auto &loop_count_actor = actor_set->loop_count_actor_; | |||||
| if (loop_count_actor != nullptr) { | |||||
| if (loop_count_actor->input_controls_num_ == 0) { | |||||
| MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) { | void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| for (auto &value_node : graph->graph_value_nodes()) { | for (auto &value_node : graph->graph_value_nodes()) { | ||||
| MS_EXCEPTION_IF_NULL(value_node); | MS_EXCEPTION_IF_NULL(value_node); | ||||
| if (!AnfAlgo::OutputAddrExist(value_node, 0)) { | |||||
| MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString(); | |||||
| continue; | |||||
| } | |||||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0); | auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0); | ||||
| DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor); | DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor); | ||||
| device_tensor->set_ref_count(SIZE_MAX); | device_tensor->set_ref_count(SIZE_MAX); | ||||
| @@ -682,6 +770,7 @@ void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(input_node); | MS_EXCEPTION_IF_NULL(input_node); | ||||
| if (IsPersistentDeviceTensor(input_node)) { | if (IsPersistentDeviceTensor(input_node)) { | ||||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0); | auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0); | ||||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||||
| DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor); | DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor); | ||||
| device_tensor->set_ref_count(SIZE_MAX); | device_tensor->set_ref_count(SIZE_MAX); | ||||
| device_tensor->ResetRefCountUsed(); | device_tensor->ResetRefCountUsed(); | ||||
| @@ -700,5 +789,144 @@ HostTensorQueue *GraphScheduler::FetchHostQueue(const KernelGraphPtr &graph) con | |||||
| } | } | ||||
| } | } | ||||
| void GraphScheduler::DumpActor(const KernelGraphPtr &graph) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| const auto &actor_set = Fetch(graph); | |||||
| MS_EXCEPTION_IF_NULL(actor_set); | |||||
| std::string filename = "./actor_set_" + graph->ToString() + ".ir"; | |||||
| std::ofstream ofs(filename); | |||||
| if (!ofs.is_open()) { | |||||
| MS_LOG(ERROR) << "Open file [" << filename << "] failed!"; | |||||
| return; | |||||
| } | |||||
| ofs << "[Data source actors]\n"; | |||||
| for (const auto &data_source_actor : actor_set->data_source_actors_) { | |||||
| DumpDSActor(data_source_actor.get(), ofs); | |||||
| ofs << "\n"; | |||||
| } | |||||
| ofs << "\n[Kernel actors]\n"; | |||||
| for (const auto &kernel_actor : actor_set->kernel_actors_) { | |||||
| DumpKernelActor(kernel_actor.get(), ofs); | |||||
| ofs << "\n"; | |||||
| } | |||||
| ofs << "\n[No input kernel actors]\n"; | |||||
| for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) { | |||||
| DumpKernelActor(no_input_kernel_actor.get(), ofs); | |||||
| ofs << "\n"; | |||||
| } | |||||
| ofs << "\n[Loop count actor]\n"; | |||||
| const auto &loop_count_actor = actor_set->loop_count_actor_; | |||||
| if (loop_count_actor != nullptr) { | |||||
| DumpLoopCountActor(loop_count_actor.get(), ofs); | |||||
| ofs << "\n"; | |||||
| } | |||||
| } | |||||
| void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const { | |||||
| MS_EXCEPTION_IF_NULL(actor); | |||||
| const auto &actor_name = actor->GetAID().Name(); | |||||
| MS_EXCEPTION_IF_NULL(actor->device_context_); | |||||
| ofs << "\tactor_name:" << actor_name << "\tdevice_context:" << actor->device_context_->device_context_key().ToString() | |||||
| << "\n"; | |||||
| if (actor_name.find("_DeviceQueueDataSourceActor") != string::npos) { | |||||
| // Dump the member info of device queue data source actor. | |||||
| const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor); | |||||
| const auto &data_kernel = device_queue_ds_actor->data_kernel_; | |||||
| MS_EXCEPTION_IF_NULL(data_kernel); | |||||
| ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope() | |||||
| << "\tinput_number:" << AnfAlgo::GetInputTensorNum(data_kernel) | |||||
| << "\toutput_number:" << AnfAlgo::GetOutputTensorNum(data_kernel) << "\n"; | |||||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(data_kernel); ++i) { | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false); | |||||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||||
| ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | |||||
| << "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n "; | |||||
| } | |||||
| } else if (actor_name.find("_HostQueueDataSourceActor") != string::npos) { | |||||
| // Dump the member info of host queue data source actor. | |||||
| const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor); | |||||
| ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n"; | |||||
| for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) { | |||||
| const auto &data_node = host_queue_ds_actor->data_nodes_[i]; | |||||
| MS_EXCEPTION_IF_NULL(data_node); | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_node, 0, false); | |||||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||||
| ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope() | |||||
| << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | |||||
| << "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n "; | |||||
| } | |||||
| } | |||||
| ofs << "\t\toutput_data_arrows:" << actor->output_op_arrows_.size() << "\n "; | |||||
| for (const auto &data_arrow : actor->output_op_arrows_) { | |||||
| MS_EXCEPTION_IF_NULL(data_arrow); | |||||
| ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_ | |||||
| << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_ | |||||
| << "\n"; | |||||
| } | |||||
| } | |||||
| void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const { | |||||
| MS_EXCEPTION_IF_NULL(actor); | |||||
| ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ | |||||
| << "\tinput_controls_num:" << actor->input_controls_num_ << "\n"; | |||||
| ofs << "\t\toutput_control_arrows:" << (actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size()) | |||||
| << "\n "; | |||||
| for (const auto &aid : actor->data_source_aids_) { | |||||
| ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n"; | |||||
| } | |||||
| for (const auto &aid : actor->no_input_kernel_aids_) { | |||||
| ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n"; | |||||
| } | |||||
| } | |||||
| void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const { | |||||
| MS_EXCEPTION_IF_NULL(actor); | |||||
| MS_EXCEPTION_IF_NULL(actor->device_context_); | |||||
| ofs << "\tactor_name:" << actor->GetAID().Name() | |||||
| << "\tdevice_context:" << actor->device_context_->device_context_key().ToString() | |||||
| << "\tinput_data_num:" << actor->input_datas_num_ << "\tinput_controls_num:" << actor->input_controls_num_ | |||||
| << "\n"; | |||||
| const auto &kernel = actor->kernel_; | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinput_number:" << AnfAlgo::GetInputTensorNum(kernel) | |||||
| << "\toutput_number:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n"; | |||||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { | |||||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false); | |||||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||||
| ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize() | |||||
| << "\tref_count:" << device_tensor->ref_count_dynamic_used() << "\n "; | |||||
| } | |||||
| ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n "; | |||||
| for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) { | |||||
| const auto &node = reinterpret_cast<AnfNode *>(device_tensor_store_key.second); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first | |||||
| << "\tfrom_node_name:" << node->fullname_with_scope() << "\n"; | |||||
| } | |||||
| ofs << "\t\toutput_data_arrows:" << actor->output_op_arrows_.size() << "\n "; | |||||
| for (const auto &data_arrow : actor->output_op_arrows_) { | |||||
| MS_EXCEPTION_IF_NULL(data_arrow); | |||||
| ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_ | |||||
| << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_ | |||||
| << "\n"; | |||||
| } | |||||
| ofs << "\t\toutput_control_arrows:" << actor->output_op_controls_.size() << "\n "; | |||||
| for (const auto &aid : actor->output_op_controls_) { | |||||
| ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n"; | |||||
| } | |||||
| } | |||||
| } // namespace runtime | } // namespace runtime | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <fstream> | |||||
| #include "runtime/framework/actor/data_source_actor.h" | #include "runtime/framework/actor/data_source_actor.h" | ||||
| #include "runtime/framework/actor/loop_count_actor.h" | #include "runtime/framework/actor/loop_count_actor.h" | ||||
| #include "runtime/framework/actor/kernel_actor.h" | #include "runtime/framework/actor/kernel_actor.h" | ||||
| @@ -62,7 +63,8 @@ class GraphScheduler { | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| // The memory manager creating and scheduling. | |||||
| // 1. Thread pool creating. | |||||
| // 2. The memory manager creating and scheduling. | |||||
| void Initialize(); | void Initialize(); | ||||
| // Transform graph to actor DAG, contains build and link. | // Transform graph to actor DAG, contains build and link. | ||||
| @@ -78,8 +80,7 @@ class GraphScheduler { | |||||
| // 1. Prepare the data of device tensor store(such as weights and value nodes of graph). | // 1. Prepare the data of device tensor store(such as weights and value nodes of graph). | ||||
| // 2. Prepare the data of host tensor queue(such as non weighted parameters of graph). | // 2. Prepare the data of host tensor queue(such as non weighted parameters of graph). | ||||
| // 3. Prepare the output tensor of graph. | // 3. Prepare the output tensor of graph. | ||||
| void PrepareRun(const KernelGraphPtr &graph, const DeviceContext *device_context, | |||||
| const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs); | |||||
| void PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs); | |||||
| // The processing entry of actors running. | // The processing entry of actors running. | ||||
| bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); | bool Run(const ActorSet *actor_set, GraphExecutionStrategy strategy = GraphExecutionStrategy::kPipeline); | ||||
| @@ -118,12 +119,21 @@ class GraphScheduler { | |||||
| GraphExecutionStrategy strategy); | GraphExecutionStrategy strategy); | ||||
| void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph); | void LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph); | ||||
| // Check whether the actor set is valid. | |||||
| bool CheckActorValid(const ActorSet *actor_set) const; | |||||
| // Persist device tensors of some graph nodes (such as weights and value nodes). | // Persist device tensors of some graph nodes (such as weights and value nodes). | ||||
| void PersistDeviceTensor(const KernelGraphPtr &graph); | void PersistDeviceTensor(const KernelGraphPtr &graph); | ||||
| // Fetch the host tensor queue by kernel graph. | // Fetch the host tensor queue by kernel graph. | ||||
| HostTensorQueue *FetchHostQueue(const KernelGraphPtr &graph) const; | HostTensorQueue *FetchHostQueue(const KernelGraphPtr &graph) const; | ||||
| // Display the actor information of the corresponding kernel graph. | |||||
| void DumpActor(const KernelGraphPtr &graph) const; | |||||
| void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const; | |||||
| void DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const; | |||||
| void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const; | |||||
| std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actors_; | std::unordered_map<KernelGraphPtr, ActorSetPtr> graph_to_actors_; | ||||
| std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_; | std::unordered_map<KernelGraphPtr, HostTensorQueuePtr> graph_to_host_queue_; | ||||
| @@ -95,7 +95,7 @@ class DeviceContext { | |||||
| // Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously, | // Synchronize stream, device such as GPU and Ascend need stream to launch kernel asynchronously, | ||||
| // using 'SyncStream' to block thread and wait for completing all tasks in stream. | // using 'SyncStream' to block thread and wait for completing all tasks in stream. | ||||
| // Devices that do not need stream could ignore the implementation of this function. | // Devices that do not need stream could ignore the implementation of this function. | ||||
| virtual bool SyncStream(size_t stream_id = 0) { return true; } | |||||
| virtual bool SyncStream(size_t stream_id = 0) const { return true; } | |||||
| // Get device_context_key_ to obtain device name and device id. | // Get device_context_key_ to obtain device name and device id. | ||||
| const DeviceContextKey &device_context_key() const { return device_context_key_; } | const DeviceContextKey &device_context_key() const { return device_context_key_; } | ||||
| @@ -270,7 +270,7 @@ bool GPUDeviceContext::LaunchKernel(KernelMod *kernel_mod, const std::vector<Add | |||||
| return kernel_mod->Launch(inputs, workspace, outputs, streams_.front()); | return kernel_mod->Launch(inputs, workspace, outputs, streams_.front()); | ||||
| } | } | ||||
| bool GPUDeviceContext::SyncStream(size_t stream_id) { | |||||
| bool GPUDeviceContext::SyncStream(size_t stream_id) const { | |||||
| if (stream_id >= streams_.size()) { | if (stream_id >= streams_.size()) { | ||||
| MS_LOG(EXCEPTION) << "The stream_id: " << stream_id << " is greater than stream array size: " << streams_.size(); | MS_LOG(EXCEPTION) << "The stream_id: " << stream_id << " is greater than stream array size: " << streams_.size(); | ||||
| } | } | ||||
| @@ -61,7 +61,7 @@ class GPUDeviceContext : public DeviceContext { | |||||
| bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs, | bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override; | const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override; | ||||
| bool SyncStream(size_t stream_id = 0) override; | |||||
| bool SyncStream(size_t stream_id = 0) const override; | |||||
| private: | private: | ||||
| DISABLE_COPY_AND_ASSIGN(GPUDeviceContext); | DISABLE_COPY_AND_ASSIGN(GPUDeviceContext); | ||||
| @@ -233,10 +233,12 @@ MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string | |||||
| : Backend(backend_name), device_name_(device_name), device_id_(device_id) {} | : Backend(backend_name), device_name_(device_name), device_id_(device_id) {} | ||||
| GraphId MindRTBackend::CompileGraph(const AnfNodePtrList &nodes) { | GraphId MindRTBackend::CompileGraph(const AnfNodePtrList &nodes) { | ||||
| MS_LOG(INFO) << "Compile graph begin."; | |||||
| // Get and set the device context. | // Get and set the device context. | ||||
| const auto &cur_device_name = GetCNodeTarget(nodes[0]); | const auto &cur_device_name = GetCNodeTarget(nodes[0]); | ||||
| const auto &device_context = | const auto &device_context = | ||||
| device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); | device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_}); | ||||
| device_context->Initialize(); | |||||
| runtime::GraphCompiler::GetInstance().set_device_context(device_context); | runtime::GraphCompiler::GetInstance().set_device_context(device_context); | ||||
| // Transform nodes to inputs and outputs. | // Transform nodes to inputs and outputs. | ||||
| @@ -246,10 +248,13 @@ GraphId MindRTBackend::CompileGraph(const AnfNodePtrList &nodes) { | |||||
| std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(nodes); | std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(nodes); | ||||
| // Compile graph. | // Compile graph. | ||||
| return runtime::GraphCompiler::GetInstance().CompileGraph(inputs, outputs); | |||||
| auto graph_id = runtime::GraphCompiler::GetInstance().CompileGraph(nodes, outputs); | |||||
| MS_LOG(INFO) << "Compile graph end, graph id: " << graph_id; | |||||
| return graph_id; | |||||
| } | } | ||||
| VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) { | VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) { | ||||
| MS_LOG(INFO) << "Run graph begin, graph id: " << graph_id; | |||||
| const auto &context_ptr = MsContext::GetInstance(); | const auto &context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) { | ||||
| @@ -257,24 +262,38 @@ VectorRef MindRTBackend::RunGraph(GraphId graph_id, const VectorRef &args) { | |||||
| return VectorRef(); | return VectorRef(); | ||||
| } | } | ||||
| // Transform args to input tensors. | |||||
| std::vector<tensor::TensorPtr> inputs; | |||||
| for (const auto &arg : args) { | |||||
| PushInputTensor(arg, &inputs); | |||||
| } | |||||
| // Fetch the kernel graph. | // Fetch the kernel graph. | ||||
| const auto &kernel_graph = runtime::GraphCompiler::GetInstance().Fetch(graph_id); | const auto &kernel_graph = runtime::GraphCompiler::GetInstance().Fetch(graph_id); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| // Transform args to input tensors. | |||||
| std::vector<tensor::TensorPtr> inputs; | |||||
| for (const auto &input_node : kernel_graph->input_nodes()) { | |||||
| const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node); | |||||
| MS_EXCEPTION_IF_NULL(front_node); | |||||
| MS_EXCEPTION_IF_NULL(front_node->func_graph()); | |||||
| const auto &origin_parameters = front_node->func_graph()->parameters(); | |||||
| const auto &iter = std::find(origin_parameters.begin(), origin_parameters.end(), front_node); | |||||
| if (iter == origin_parameters.end()) { | |||||
| MS_LOG(EXCEPTION) << "Parameter node: " << front_node->fullname_with_scope() << " is not exist."; | |||||
| } | |||||
| auto position = IntToSize(std::distance(origin_parameters.begin(), iter)); | |||||
| PushInputTensor(args[position], &inputs); | |||||
| } | |||||
| // Fetch the actor DAG. | // Fetch the actor DAG. | ||||
| const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(kernel_graph); | const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(kernel_graph); | ||||
| MS_EXCEPTION_IF_NULL(actor_set); | MS_EXCEPTION_IF_NULL(actor_set); | ||||
| // Run actor DAG, wait interface of GraphScheduler to create outputs. | |||||
| // Run actor DAG. | |||||
| VectorRef outputs; | VectorRef outputs; | ||||
| runtime::GraphScheduler::GetInstance().Run(actor_set); | |||||
| runtime::GraphScheduler::GetInstance().PrepareRun(kernel_graph, &inputs, &outputs); | |||||
| if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) { | |||||
| MS_LOG(EXCEPTION) << "The graph runs failed, graph id: " << graph_id | |||||
| << ", graph name: " << kernel_graph->ToString(); | |||||
| } | |||||
| MS_LOG(INFO) << "Run graph end, graph id: " << graph_id; | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| } // namespace compile | } // namespace compile | ||||
| @@ -114,33 +114,34 @@ static int GetSlogLevel(MsLogLevel level) { | |||||
| static const char *GetSubModuleName(SubModuleId module_id) { | static const char *GetSubModuleName(SubModuleId module_id) { | ||||
| static const char *sub_module_names[NUM_SUBMODUES] = { | static const char *sub_module_names[NUM_SUBMODUES] = { | ||||
| "UNKNOWN", // SM_UNKNOWN | |||||
| "CORE", // SM_CORE | |||||
| "ANALYZER", // SM_ANALYZER | |||||
| "COMMON", // SM_COMMON | |||||
| "DEBUG", // SM_DEBUG | |||||
| "OFFLINE_DEBUG", // SM_OFFLINE_DEBUG | |||||
| "DEVICE", // SM_DEVICE | |||||
| "GE_ADPT", // SM_GE_ADPT | |||||
| "IR", // SM_IR | |||||
| "KERNEL", // SM_KERNEL | |||||
| "MD", // SM_MD | |||||
| "ME", // SM_ME | |||||
| "EXPRESS", // SM_EXPRESS | |||||
| "OPTIMIZER", // SM_OPTIMIZER | |||||
| "PARALLEL", // SM_PARALLEL | |||||
| "PARSER", // SM_PARSER | |||||
| "PIPELINE", // SM_PIPELINE | |||||
| "PRE_ACT", // SM_PRE_ACT | |||||
| "PYNATIVE", // SM_PYNATIVE | |||||
| "SESSION", // SM_SESSION | |||||
| "UTILS", // SM_UTILS | |||||
| "VM", // SM_VM | |||||
| "PROFILER", // SM_PROFILER | |||||
| "PS", // SM_PS | |||||
| "LITE", // SM_LITE | |||||
| "HCCL_ADPT", // SM_HCCL_ADPT | |||||
| "MINDQUANTUM" // SM_MINDQUANTUM | |||||
| "UNKNOWN", // SM_UNKNOWN | |||||
| "CORE", // SM_CORE | |||||
| "ANALYZER", // SM_ANALYZER | |||||
| "COMMON", // SM_COMMON | |||||
| "DEBUG", // SM_DEBUG | |||||
| "OFFLINE_DEBUG", // SM_OFFLINE_DEBUG | |||||
| "DEVICE", // SM_DEVICE | |||||
| "GE_ADPT", // SM_GE_ADPT | |||||
| "IR", // SM_IR | |||||
| "KERNEL", // SM_KERNEL | |||||
| "MD", // SM_MD | |||||
| "ME", // SM_ME | |||||
| "EXPRESS", // SM_EXPRESS | |||||
| "OPTIMIZER", // SM_OPTIMIZER | |||||
| "PARALLEL", // SM_PARALLEL | |||||
| "PARSER", // SM_PARSER | |||||
| "PIPELINE", // SM_PIPELINE | |||||
| "PRE_ACT", // SM_PRE_ACT | |||||
| "PYNATIVE", // SM_PYNATIVE | |||||
| "SESSION", // SM_SESSION | |||||
| "UTILS", // SM_UTILS | |||||
| "VM", // SM_VM | |||||
| "PROFILER", // SM_PROFILER | |||||
| "PS", // SM_PS | |||||
| "LITE", // SM_LITE | |||||
| "HCCL_ADPT", // SM_HCCL_ADPT | |||||
| "MINDQUANTUM", // SM_MINDQUANTUM | |||||
| "RUNTIME_FRAMEWORK" // SM_RUNTIME_FRAMEWORK | |||||
| }; | }; | ||||
| return sub_module_names[module_id % NUM_SUBMODUES]; | return sub_module_names[module_id % NUM_SUBMODUES]; | ||||