| @@ -34,6 +34,7 @@ void KernelActor::Init() { | |||
| is_dynamic_shape_ = AnfAlgo::IsDynamicShape(kernel_); | |||
| // Init the device tensors and kernel launch info. | |||
| copy_input_device_tensors_.resize(real_input_num_); | |||
| input_device_tensors_.resize(real_input_num_); | |||
| for (auto &input_address : input_device_tensors_) { | |||
| memory_free_list_.emplace_back(input_address); | |||
| @@ -277,6 +278,42 @@ void KernelActor::PushInputDeviceTensor(const std::vector<TensorPtr> *input_tens | |||
| } | |||
| } | |||
| void KernelActor::CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context) { | |||
| MS_EXCEPTION_IF_NULL(input_data); | |||
| if ((input_data->data_ == nullptr) || (input_data->data_->DeviceType() == device_context_->GetDeviceAddressType())) { | |||
| return; | |||
| } | |||
| MS_LOG(DEBUG) << "Copy from device type: " << input_data->data_->DeviceType() | |||
| << " to device type: " << device_context_->GetDeviceAddressType() << " in " << GetAID().Name(); | |||
| if (copy_input_device_tensors_[input_data->index_] == nullptr) { | |||
| copy_input_device_tensors_[input_data->index_] = device_context_->CreateDeviceAddress( | |||
| nullptr, input_data->data_->GetSize(), input_data->data_->format(), input_data->data_->type_id()); | |||
| } | |||
| // Dynamic shape need update size. | |||
| copy_input_device_tensors_[input_data->index_]->SetSize(input_data->data_->GetSize()); | |||
| if (copy_input_device_tensors_[input_data->index_]->GetPtr() == nullptr) { | |||
| if (!device_context_->AllocateMemory(copy_input_device_tensors_[input_data->index_].get(), | |||
| copy_input_device_tensors_[input_data->index_]->GetSize())) { | |||
| std::string error_info = | |||
| "Device(id:" + std::to_string(device_context_->device_context_key().device_id_) + | |||
| ") memory isn't enough and alloc failed, actor name: " + GetAID().Name() + | |||
| ", alloc size: " + std::to_string(copy_input_device_tensors_[input_data->index_]->GetSize()); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| } | |||
| if (!Copy(copy_input_device_tensors_[input_data->index_].get(), input_data->data_)) { | |||
| std::string error_info = "Copy device tensor failed: " + GetAID().Name(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| // Update by the copy input device tensor. | |||
| input_device_tensors_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get(); | |||
| memory_free_list_[input_data->index_] = copy_input_device_tensors_[input_data->index_].get(); | |||
| } | |||
| void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| MS_EXCEPTION_IF_NULL(device_context_); | |||
| @@ -289,6 +326,7 @@ void KernelActor::FetchInputDeviceTensor(OpContext<DeviceTensor> *context) { | |||
| input_device_tensors_[input_data->index_] = input_data->data_; | |||
| memory_free_list_[input_data->index_] = input_data->data_; | |||
| } | |||
| CopyInputDeviceTensor(input_data, context); | |||
| } | |||
| } | |||
| @@ -91,6 +91,7 @@ class KernelActor : public DebugAwareActor { | |||
| // Fetch the device tensor for launch. | |||
| void FetchInputDeviceTensor(OpContext<DeviceTensor> *context); | |||
| void FetchOutputDeviceTensor(); | |||
| void CopyInputDeviceTensor(const OpData<DeviceTensor> *input_data, OpContext<DeviceTensor> *context); | |||
| // In step mode, push the input tensors which contain valid device address into input_device_tensors_ directly. | |||
| void PushInputDeviceTensor(const std::vector<TensorPtr> *input_tensors); | |||
| @@ -144,6 +145,9 @@ class KernelActor : public DebugAwareActor { | |||
| std::vector<DeviceTensor *> input_device_tensors_; | |||
| std::vector<DeviceTensor *> output_device_tensors_; | |||
| std::vector<DeviceTensor *> workspace_device_tensors_; | |||
| // The received input device type may be different from the device context type in the control flow and host device | |||
| // scenarios, so it needs to be copied from the input device type to the device context type. | |||
| std::vector<DeviceTensorPtr> copy_input_device_tensors_; | |||
| // The device tensors for memory alloc and free. | |||
| // output + workspace | |||
| @@ -1287,6 +1287,13 @@ void ControlNodeParser::FetchHostParameterToWeight(const RealToFormalNode &front | |||
| std::vector<AnfNodePtr> dest_nodes; | |||
| FetchWeightbyHostParameter(pair.first, &dest_nodes, front_to_front_parameters); | |||
| host_parameter_to_weights_[pair.first] = dest_nodes; | |||
| if (std::find(root_graph_parameters_.begin(), root_graph_parameters_.end(), pair.first) != | |||
| root_graph_parameters_.end()) { | |||
| for (auto &sub_front_node : dest_nodes) { | |||
| sub_front_node_to_root_front_node_[sub_front_node] = pair.first; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -1584,5 +1591,12 @@ void ControlNodeParser::FetchAutoMonadNode(const std::vector<AnfNodePtr> &contro | |||
| } | |||
| } | |||
| } | |||
| AnfNodePtr ControlNodeParser::FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node) { | |||
| if (sub_front_node_to_root_front_node_.count(sub_front_node) == 0) { | |||
| return sub_front_node; | |||
| } | |||
| return sub_front_node_to_root_front_node_[sub_front_node]; | |||
| } | |||
| } // namespace runtime | |||
| } // namespace mindspore | |||
| @@ -126,6 +126,8 @@ class ControlNodeParser { | |||
| return front_to_backend_kernels_[front_node_with_index].first; | |||
| } | |||
| AnfNodePtr FetchRootGraphFrontNodeBySubFrontNode(const AnfNodePtr &sub_front_node); | |||
| private: | |||
| friend class GraphScheduler; | |||
| @@ -221,6 +223,7 @@ class ControlNodeParser { | |||
| // host parameter to weights records the weights in the subgraph corresponding to the node in the root funcgraph. | |||
| // When initializing the weights, all related weights need to be recorded as the same device tensor. | |||
| HostParameterToWeight host_parameter_to_weights_; | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> sub_front_node_to_root_front_node_; | |||
| // The front value node saves all value nodes that are not in the kernel graph. These nodes are generally the | |||
| // input of the control node. | |||
| @@ -248,7 +248,14 @@ void PrepareDataForControlWeightNode( | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get()); | |||
| if (device_tensors.empty()) { | |||
| bool need_update_device_tensor_store = (device_tensors.size() == 0) ? true : false; | |||
| for (auto &device_tensor : device_tensors) { | |||
| if (device_tensor->GetPtr() == nullptr) { | |||
| need_update_device_tensor_store = true; | |||
| break; | |||
| } | |||
| } | |||
| if (need_update_device_tensor_store) { | |||
| PrepareDataForWeightNode(node, front_node, tensor, device_context); | |||
| } | |||
| @@ -455,7 +462,7 @@ void GraphScheduler::Initialize() { | |||
| auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS"); | |||
| MS_LOG(INFO) << "The actor thread number: " << actor_thread_num | |||
| << ", the computed OMP thread number : " << OMP_thread_num | |||
| << ", the used OMP thread number : " << stoi(OMP_thread_num_used); | |||
| << ", the used OMP thread number : " << OMP_thread_num_used; | |||
| BuildAndScheduleGlobalActor(); | |||
| } | |||
| @@ -2719,20 +2726,27 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler | |||
| for (auto &input_node : graph->input_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| AnfNodePtr front_node = nullptr; | |||
| AnfNodePtr sub_front_node = nullptr; | |||
| if (IsInternalParameter(input_node, graph)) { | |||
| auto front_node_with_index = graph->GetFrontNodeByInternalParameter(input_node); | |||
| MS_EXCEPTION_IF_NULL(front_node_with_index.first); | |||
| const auto &front_output_with_index = | |||
| AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false); | |||
| front_node = front_output_with_index.first; | |||
| } else if (IsPersistentDeviceTensor(input_node)) { | |||
| front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| sub_front_node = front_output_with_index.first; | |||
| } else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) { | |||
| sub_front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| } | |||
| if (front_node == nullptr) { | |||
| if (sub_front_node == nullptr) { | |||
| continue; | |||
| } | |||
| // The sub front nodes share the device tensor store with the root front node. | |||
| auto front_node = sub_front_node; | |||
| if (graph_compiler_info.control_node_parser_ != nullptr) { | |||
| front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node); | |||
| } | |||
| MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString() | |||
| << ", root front node:" << front_node->DebugString(); | |||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| if (IsPersistentDeviceTensor(input_node)) { | |||
| @@ -3091,7 +3105,12 @@ void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compil | |||
| if (!IsPersistentDeviceTensor(input_node)) { | |||
| continue; | |||
| } | |||
| const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| const auto &sub_front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| // The sub front nodes share the device tensor store with the root front node. | |||
| auto front_node = sub_front_node; | |||
| if (graph_compiler_info.control_node_parser_ != nullptr) { | |||
| front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node); | |||
| } | |||
| const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get()); | |||
| ofs << "\t\tdevcie tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size() | |||
| << "\n"; | |||