| @@ -320,6 +320,7 @@ void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const contex | |||
| void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors, | |||
| OpContext<DeviceTensor> *const context) { | |||
| const auto &parser = graph_compiler_info_->control_node_parser_; | |||
| for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) { | |||
| const auto &graph = graph_compiler_info_->graphs_[i]; | |||
| const auto &device_context = graph_compiler_info_->device_contexts_[i]; | |||
| @@ -338,15 +339,16 @@ void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::ve | |||
| const auto &input_node = input_nodes[j]; | |||
| const auto &input_tensor = tensors[j]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (!IsPersistentDeviceTensor(input_node)) { | |||
| const auto front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| if (!IsPersistentDeviceTensor(input_node) || | |||
| (parser != nullptr && parser->IsInited() && (!parser->IsRootGraphParameter(front_node)))) { | |||
| continue; | |||
| } | |||
| const auto front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| PrepareDataForWeightNode(input_node, front_node, input_tensor, device_context, context); | |||
| } | |||
| } | |||
| PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context); | |||
| PrepareDataForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context); | |||
| } | |||
| void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors, | |||
| @@ -699,51 +701,14 @@ void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, | |||
| } | |||
| } | |||
| // In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor. | |||
| void DataPrepareActor::PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node, | |||
| const TensorPtr &tensor, const DeviceContext *device_context, | |||
| const HostParameterToWeight &host_parameter_to_weights, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(front_node); | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get()); | |||
| bool need_update_device_tensor_store = (device_tensors.size() == 0) ? true : false; | |||
| for (auto &device_tensor : device_tensors) { | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| // Unlike the CPU and GPU platforms, the device addresses of subgraph weight parameters on the Ascend platform | |||
| // have already been allocated during compilation, so these weight parameters still need to be updated. | |||
| if (device_tensor->GetPtr() == nullptr || device_tensor->is_ptr_persisted()) { | |||
| need_update_device_tensor_store = true; | |||
| break; | |||
| } | |||
| } | |||
| if (need_update_device_tensor_store) { | |||
| PrepareDataForWeightNode(node, front_node, tensor, device_context, context); | |||
| } | |||
| const auto iter = host_parameter_to_weights.find(front_node); | |||
| if (iter == host_parameter_to_weights.end()) { | |||
| void DataPrepareActor::PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser, | |||
| const std::vector<TensorPtr> &tensors, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(control_node_parser); | |||
| if (!control_node_parser->IsInited()) { | |||
| return; | |||
| } | |||
| // Fetch all the device tensors of host weight node and insert as the weight of other nodes. | |||
| const auto &sub_front_nodes = host_parameter_to_weights.at(front_node); | |||
| device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get()); | |||
| for (const auto &sub_front_node : sub_front_nodes) { | |||
| for (const auto &device_tensor : device_tensors) { | |||
| MS_EXCEPTION_IF_NULL(sub_front_node); | |||
| DeviceTensorStore::GetInstance().Insert(sub_front_node.get(), device_tensor); | |||
| } | |||
| } | |||
| } | |||
| void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser, | |||
| const std::vector<TensorPtr> &tensors, | |||
| OpContext<DeviceTensor> *const context) { | |||
| MS_EXCEPTION_IF_NULL(control_node_parser); | |||
| for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(value_node_with_context.first.first); | |||
| if (AnfAlgo::OutputAddrExist(value_node_with_context.first.first, 0)) { | |||
| @@ -753,19 +718,34 @@ void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeP | |||
| const auto &control_node_parameters = control_node_parser->control_node_parameters(); | |||
| for (size_t i = 0; i < control_node_parameters.size(); ++i) { | |||
| const auto &input_node = control_node_parameters[i]; | |||
| const auto &input_tensor = tensors[i]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (IsPersistentDeviceTensor(input_node)) { | |||
| const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters(); | |||
| const auto &iter = front_to_backend_parameters.find({input_node, 0}); | |||
| if (iter == front_to_backend_parameters.end() || iter->second.empty()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" | |||
| << AnfAlgo::GetNodeDebugString(input_node); | |||
| } | |||
| const auto &node_with_context = iter->second.begin(); | |||
| PrepareDataForControlWeightNode(node_with_context->first, input_node, input_tensor, node_with_context->second, | |||
| control_node_parser->host_parameter_to_weights(), context); | |||
| const auto &front_node = control_node_parameters[i]; | |||
| MS_EXCEPTION_IF_NULL(front_node); | |||
| if ((!IsPersistentDeviceTensor(front_node)) || (!control_node_parser->IsRootGraphParameter(front_node))) { | |||
| continue; | |||
| } | |||
| const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters(); | |||
| const auto &iter = front_to_backend_parameters.find({front_node, 0}); | |||
| if (iter == front_to_backend_parameters.end() || iter->second.empty()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << AnfAlgo::GetNodeDebugString(front_node); | |||
| } | |||
| const auto &node_with_context = iter->second.begin(); | |||
| const auto &backend_node = node_with_context->first; | |||
| const auto &device_context = node_with_context->second; | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get()); | |||
| if (device_tensors.empty()) { | |||
| std::string error_info = "Failed to get device tensor for front node:" + front_node->DebugString(); | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info); | |||
| } | |||
| // Unlike the CPU and GPU platforms, the device addresses of subgraph weight parameters on the Ascend platform | |||
| // have already been allocated during compilation, so these weight parameters still need to be updated. | |||
| if (device_tensors[0] != nullptr && | |||
| (device_tensors[0]->GetPtr() == nullptr || device_tensors[0]->is_ptr_persisted())) { | |||
| PrepareDataForWeightNode(backend_node, front_node, tensors[i], device_context, context); | |||
| } | |||
| } | |||
| } | |||
| @@ -90,17 +90,13 @@ class DataPrepareActor : public DebugAwareActor { | |||
| const DeviceContext *device_context, OpContext<DeviceTensor> *const context); | |||
| // Data preparation in the control flow scenario. | |||
| void PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser, | |||
| const std::vector<TensorPtr> &tensors, | |||
| OpContext<DeviceTensor> *const context); | |||
| // If a parameter of the root graph is used only by control nodes, it is not initialized by the kernel graph, | |||
| // so an address has to be allocated for it separately. | |||
| void PrepareDataForControlNode(const ControlNodeParserPtr &control_node_parser, const std::vector<TensorPtr> &tensors, | |||
| OpContext<DeviceTensor> *const context); | |||
| void PrepareHostTensorQueueForControlNode(const std::vector<TensorPtr> &tensors, | |||
| std::vector<TensorPtr> *const host_tensors, | |||
| OpContext<DeviceTensor> *const context); | |||
| // In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor. | |||
| void PrepareDataForControlWeightNode(const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, | |||
| const DeviceContext *device_context, | |||
| const HostParameterToWeight &host_parameter_to_weights, | |||
| OpContext<DeviceTensor> *const context); | |||
| void PrepareDataForControlValueNode(const KernelWithIndex &node_with_index, const DeviceContext *device_context, | |||
| OpContext<DeviceTensor> *const context); | |||
| @@ -945,7 +945,7 @@ void ControlNodeParser::Parse(const std::vector<AnfNodePtr> &control_nodes, cons | |||
| ParseFirstControlNodeForFuncGraph(control_nodes); | |||
| } | |||
| bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node) { | |||
| bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // Has no control flow node. | |||
| if (!IsInited()) { | |||
| @@ -956,26 +956,24 @@ bool ControlNodeParser::IsControlFlowDataArrow(const KernelGraphPtr &graph, cons | |||
| return true; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<Parameter>()) { | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| if (!backend_node->isa<Parameter>()) { | |||
| return false; | |||
| } | |||
| auto parameter_node = node->cast<ParameterPtr>(); | |||
| auto parameter_node = backend_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(parameter_node); | |||
| // Parameter input should be linked to its entrance actor. | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(node); | |||
| auto internal_node_with_index = graph->GetFrontNodeByInternalParameter(node); | |||
| auto front_node = graph->GetFrontAnfByBackendAnf(backend_node); | |||
| auto internal_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node); | |||
| front_node = (front_node != nullptr ? front_node : internal_node_with_index.first); | |||
| if (front_node == nullptr) { | |||
| auto front_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(node); | |||
| auto front_node_with_index = graph->GetElementInTupleBackendFrontIndexMap(backend_node); | |||
| front_node = front_node_with_index.first; | |||
| } | |||
| // If parameter is a weight node, it should be set to kernel actor directly. | |||
| if (AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()) || | |||
| (front_node != nullptr && front_node->isa<Parameter>() && | |||
| AnfAlgo::IsParameterWeight(front_node->cast<ParameterPtr>()))) { | |||
| MS_EXCEPTION_IF_NULL(front_node); | |||
| // If parameter is a weight node in root funcgraph, it should be set to kernel actor directly. | |||
| if (IsRootGraphParameter(front_node) && AnfAlgo::IsParameterWeight(backend_node->cast<ParameterPtr>())) { | |||
| return false; | |||
| } | |||
| @@ -125,7 +125,7 @@ class ControlNodeParser { | |||
| // There are two situations: | |||
| // 1. In control flow, the parameter input needs to be connected to the entrance actor of the funcgraph. | |||
| // 2. In the kernel graph with call node input, the data arrow needs to be connected to the stack actor. | |||
| bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &node); | |||
| bool IsControlFlowDataArrow(const KernelGraphPtr &graph, const AnfNodePtr &backend_node); | |||
| bool IsRootGraphParameter(const AnfNodePtr &node); | |||
| bool IsRecursionCallNode(const AnfNodePtr &node); | |||
| // If there is a recursive call node in the input of the kernel graph, the graph is recursive. | |||
| @@ -1862,6 +1862,7 @@ void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const { | |||
| } | |||
| void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) { | |||
| const auto &parser = graph_compiler_info.control_node_parser_; | |||
| for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) { | |||
| const auto &graph = graph_compiler_info.graphs_[i]; | |||
| const auto &device_context = graph_compiler_info.device_contexts_[i]; | |||
| @@ -1882,24 +1883,18 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler | |||
| for (auto &input_node : graph->input_nodes()) { | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| AnfNodePtr sub_front_node = nullptr; | |||
| AnfNodePtr front_node = nullptr; | |||
| if (IsInternalParameter(input_node, graph)) { | |||
| auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node); | |||
| sub_front_node = front_output_with_index.first; | |||
| } else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) { | |||
| sub_front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| front_node = front_output_with_index.first; | |||
| } else if (IsPersistentDeviceTensor(input_node)) { | |||
| front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| } | |||
| if (sub_front_node == nullptr) { | |||
| if (front_node == nullptr || (!IsPersistentDeviceTensor(front_node)) || | |||
| (parser != nullptr && parser->IsInited() && (!parser->IsRootGraphParameter(front_node)))) { | |||
| continue; | |||
| } | |||
| // The sub front nodes share the device tensor store with the root front node. | |||
| MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_); | |||
| auto front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node); | |||
| MS_EXCEPTION_IF_NULL(front_node); | |||
| MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString() | |||
| << ", root front node:" << front_node->DebugString(); | |||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false); | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| if (IsPersistentDeviceTensor(input_node)) { | |||
| @@ -1907,14 +1902,11 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler | |||
| AddDeviceTensorStore(front_node.get(), device_tensor); | |||
| } | |||
| // Share the weight between host and device: in this case input_node is an internal parameter and front_node is a weight. | |||
| if (!IsPersistentDeviceTensor(front_node)) { | |||
| continue; | |||
| } | |||
| if (device_tensor->is_ptr_persisted()) { | |||
| device_tensor->SetNodeIndex(input_node, 0); | |||
| AddDeviceTensorStore(front_node.get(), device_tensor); | |||
| } | |||
| // If no device tensor of this device type exists in the device tensor store, create a new device tensor of this type. | |||
| if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) { | |||
| MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope() | |||
| @@ -1927,17 +1919,38 @@ void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler | |||
| } | |||
| } | |||
| } | |||
| PersistDeviceTensorForControlNode(graph_compiler_info); | |||
| } | |||
| void GraphScheduler::PersistDeviceTensorForControlNode(const GraphCompilerInfo &graph_compiler_info) { | |||
| const auto &parser = graph_compiler_info.control_node_parser_; | |||
| if (parser == nullptr) { | |||
| if (parser == nullptr || (!parser->IsInited())) { | |||
| return; | |||
| } | |||
| for (const auto &sub_front_node_to_root_front_node : parser->sub_front_node_to_root_front_node_) { | |||
| auto device_tensors = DeviceTensorStore::GetInstance().Fetch(sub_front_node_to_root_front_node.second.get()); | |||
| for (const auto &device_tensor : device_tensors) { | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| AddDeviceTensorStore(sub_front_node_to_root_front_node.first.get(), device_tensor); | |||
| const auto &control_node_parameters = parser->control_node_parameters(); | |||
| for (size_t i = 0; i < control_node_parameters.size(); ++i) { | |||
| const auto &input_node = control_node_parameters[i]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if ((!IsPersistentDeviceTensor(input_node)) || (!parser->IsRootGraphParameter(input_node))) { | |||
| continue; | |||
| } | |||
| const auto &front_to_backend_parameters = parser->front_to_backend_parameters(); | |||
| const auto &iter = front_to_backend_parameters.find({input_node, 0}); | |||
| if (iter == front_to_backend_parameters.end() || iter->second.empty()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:" << input_node->DebugString(); | |||
| } | |||
| const auto &node_with_context = iter->second.begin(); | |||
| const auto &backend_node = node_with_context->first; | |||
| const auto &device_context = node_with_context->second; | |||
| MS_EXCEPTION_IF_NULL(backend_node); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| if (!DeviceTensorStore::GetInstance().Fetch(input_node.get()).empty()) { | |||
| continue; | |||
| } | |||
| auto device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false); | |||
| MS_EXCEPTION_IF_NULL(device_tensor); | |||
| AddDeviceTensorStore(input_node.get(), device_tensor); | |||
| } | |||
| } | |||
| @@ -179,6 +179,7 @@ class GraphScheduler { | |||
| // Persist the device tensors of some graph nodes (such as weights and value nodes). | |||
| void PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info); | |||
| void PersistDeviceTensorForControlNode(const GraphCompilerInfo &graph_compiler_info); | |||
| // Display the actor information of corresponding kernel graph. | |||
| void DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const; | |||