Merge pull request !26595 from caifubi/master-pynative-mindrt-lazy-build-with-ascend tags/v1.6.0
| @@ -21,6 +21,7 @@ | |||
| #include <memory> | |||
| #include "runtime/device/ascend/ge_runtime/task_info.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "runtime/device/executor/dynamic_kernel.h" | |||
| #ifndef ENABLE_SECURITY | |||
| #include "debug/data_dump/dump_json_parser.h" | |||
| #endif | |||
| @@ -44,9 +45,19 @@ class AscendKernelMod : public KernelMod { | |||
| #endif | |||
| } | |||
| void InitDynamicKernel(const CNodePtr &cnode_ptr, void *stream) { | |||
| if (dynamic_kernel_ == nullptr) { | |||
| stream_ = stream; | |||
| dynamic_kernel_ = GenDynamicKernel(cnode_ptr, stream); | |||
| dynamic_kernel_->Initialize(); | |||
| } | |||
| } | |||
| device::DynamicKernelPtr DynamicKernel() const { return dynamic_kernel_; } | |||
| protected: | |||
| uint32_t block_dim_{1}; | |||
| uint32_t stream_id_{0}; | |||
| device::DynamicKernelPtr dynamic_kernel_{nullptr}; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -2190,7 +2190,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf | |||
| // set execution order | |||
| std::vector<CNodePtr> exe_order = {cnode}; | |||
| graph->set_execution_order(exe_order); | |||
| // set output | |||
| if (is_ascend) { | |||
| graph->set_output(cnode); | |||
| } else { | |||
| @@ -96,7 +96,8 @@ void ResetMindRTEnable(const ResourcePtr &res) { | |||
| auto manager = func_graph->manager(); | |||
| size_t graph_nums = manager->func_graphs().size(); | |||
| // Heterogeneous scenario | |||
| if (graph_nums == 1 && context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice) { | |||
| if (graph_nums == 1 && (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice || | |||
| context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode)) { | |||
| return; | |||
| } | |||
| if (common::GetEnv("ENABLE_ASCEND_MINDRT") == "1" || common::kEnableAscendMindRT) { | |||
| @@ -2116,7 +2116,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ | |||
| #endif | |||
| VectorRef outputs; | |||
| if (!enable_mind_rt || cur_target == "Ascend") { | |||
| if (!enable_mind_rt) { | |||
| auto cur_session = GetCurrentSession(cur_target, device_id); | |||
| MS_EXCEPTION_IF_NULL(cur_session); | |||
| cur_session->RunOp(&op_run_info, &outputs); | |||
| @@ -178,9 +178,7 @@ void AscendDeviceAddress::BindDevice() const { | |||
| device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_}); | |||
| auto ascend_device_context = dynamic_cast<AscendDeviceContext *>(device_context); | |||
| MS_EXCEPTION_IF_NULL(ascend_device_context); | |||
| if (!ascend_device_context->BindDeviceToCurrentThread()) { | |||
| MS_LOG(EXCEPTION) << "BindDeviceToCurrentThread failed."; | |||
| } | |||
| ascend_device_context->BindDeviceToCurrentThread(); | |||
| } else { | |||
| MS_LOG(WARNING) << "device name is null."; | |||
| } | |||
| @@ -160,6 +160,29 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with | |||
| SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); | |||
| } | |||
| } | |||
| void UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto ref_node_map = graph->GetRefMap(); | |||
| for (auto iter : ref_node_map) { | |||
| auto &output_pair = iter.first; | |||
| auto &input_pair = iter.second; | |||
| auto &ref_node = output_pair.first; | |||
| auto output_index = output_pair.second; | |||
| auto &input_node = input_pair.first; | |||
| auto input_node_output_index = input_pair.second; | |||
| auto input_addr = AnfAlgo::GetMutableOutputAddr(input_node, input_node_output_index); | |||
| auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(ref_node, output_index); | |||
| // Just compare shared_ptr of two DeviceAddress. | |||
| // The ptr of DeviceAddress may still be nullptr. | |||
| if (input_addr != ref_node_output_addr) { | |||
| // The output of RefNode will not be used by subsequent Node. | |||
| // So update the DeviceAddress of the kernel directly instead of updating the ptr of the DeviceAddress. | |||
| AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| void DataPrepareActor::Init() { | |||
| MS_EXCEPTION_IF_NULL(graph_compiler_info_); | |||
| @@ -295,6 +318,9 @@ void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::ve | |||
| const auto front_node = FetchFrontNodeByBackendNode(input_node, graph); | |||
| PrepareDataForWeightNode(input_node, front_node, input_tensor, device_context, context); | |||
| } | |||
| // The DeviceAddress of the graph parameter has been updated. | |||
| // The output address of RefNode needs to be consistent with the address of parameter. | |||
| UpdateRefNodeOutputDeviceAddress(graph); | |||
| } | |||
| PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context); | |||
| @@ -335,7 +335,11 @@ void KernelActor::FetchOutputDeviceTensor() { | |||
| MS_EXCEPTION_IF_NULL(output_address); | |||
| if (output_size_list[i] != output_address->GetSize()) { | |||
| // The size of output address may be changed in dynamic shape scenario. | |||
| output_address->SetSize(output_size_list[i]); | |||
| // If the format of the DeviceAddress is different, then the size is originally different. | |||
| // Such as NCHW(1,1,1,3) and NC1HWC0(1,1,1,1,16). So we don't need to update the size. | |||
| if (AnfAlgo::GetOutputFormat(kernel_, i) == output_address->format()) { | |||
| output_address->SetSize(output_size_list[i]); | |||
| } | |||
| } | |||
| // When the tensor is the output of graph or in dynamic shape scenario, the output tensor may be changed. | |||
| @@ -402,6 +402,10 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| const auto &ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| return graph->graph_id(); | |||
| } | |||
| #ifdef ENABLE_DUMP_IR | |||
| bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| // Dump .pb graph before graph optimization. | |||
| @@ -426,10 +430,8 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic | |||
| // Adjust kernel graph before run graph. | |||
| device_context->PreprocessBeforeRunGraph(graph); | |||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { | |||
| // Create device address for all anf nodes of graph. | |||
| CreateDeviceAddress(graph, device_context); | |||
| } | |||
| // Create device address for all anf nodes of graph. | |||
| CreateDeviceAddress(graph, device_context); | |||
| graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get())); | |||
| @@ -482,14 +484,16 @@ GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool | |||
| // Generate kernel graph. | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| KernelGraphPtr graph = | |||
| session_->ConstructSingleOpGraph(op_run_info, op_run_info.input_tensors, op_run_info.tensor_mask); | |||
| session_->ConstructSingleOpGraph(op_run_info, op_run_info.input_tensors, op_run_info.tensor_mask, | |||
| device_context->GetDeviceAddressType() == device::DeviceAddressType::kAscend); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| // session_ is SessionBasic, AscendUnifyMindIR has not been executed. | |||
| device_context->UnifyMindIR(graph); | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| device_context->OptimizeSingleOpGraph(graph); | |||
| device_context->PreprocessBeforeRunSingleOpGraph(graph); | |||
| // Create device address for all anf nodes of graph. | |||
| CreateDeviceAddressWithoutWorkspace(graph, device_context); | |||
| @@ -520,6 +524,7 @@ void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graph | |||
| device_context->CreateKernel(node_to_build); | |||
| for (const auto &graph : graphs) { | |||
| device_context->PreprocessBeforeRunSingleOpGraph(graph); | |||
| CreateKernelWorkspaceDeviceAddress(device_context, graph); | |||
| } | |||
| } | |||
| @@ -553,6 +558,7 @@ void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &gr | |||
| CreateValueNodeDeviceAddress(device_context, graph); | |||
| CreateKernelOutputDeviceAddress(device_context, graph); | |||
| UpdateDeviceAddressForInplaceNode(graph); | |||
| UpdateDeviceAddressForRefNode(graph); | |||
| } | |||
| void GraphCompiler::GetParamAndOutputIndex( | |||
| @@ -26,6 +26,8 @@ | |||
| #include "runtime/device/ascend/ascend_stream_assign.h" | |||
| #include "runtime/device/ascend/kernel_build_ascend.h" | |||
| #include "runtime/hardware/ascend/ascend_graph_optimization.h" | |||
| #include "backend/kernel_compiler/ascend_kernel_mod.h" | |||
| #include "runtime/device/ascend/ascend_bucket.h" | |||
| #ifndef ENABLE_SECURITY | |||
| #include "debug/data_dump/dump_json_parser.h" | |||
| @@ -284,10 +286,18 @@ void AscendDeviceContext::Initialize() { | |||
| DumpInit(rank_id_); | |||
| #endif | |||
| compute_stream_ = runtime_instance_->compute_stream(); | |||
| communication_stream_ = runtime_instance_->communication_stream(); | |||
| initialized_ = true; | |||
| MS_LOG(INFO) << "Status record: Initialize success."; | |||
| } | |||
| bool AscendDeviceContext::IsGraphMode() { | |||
| auto context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| return context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode; | |||
| } | |||
| void AscendDeviceContext::Destroy() { | |||
| MS_LOG(INFO) << "Status record: Enter Destroy..."; | |||
| if (!initialized_) { | |||
| @@ -306,7 +316,7 @@ void AscendDeviceContext::Destroy() { | |||
| std::vector<GraphSegmentPtr> AscendDeviceContext::PartitionGraph( | |||
| const FuncGraphPtr &func_graph, const std::vector<GraphSegmentPtr> &default_partition_segments) { | |||
| return std::vector<GraphSegmentPtr>(); | |||
| return IsGraphMode() ? std::vector<GraphSegmentPtr>() : default_partition_segments; | |||
| } | |||
| void AscendDeviceContext::UnifyMindIR(const KernelGraphPtr &graph) const { | |||
| @@ -544,27 +554,71 @@ bool AscendDeviceContext::SyncStream(size_t stream_id) const { | |||
| bool AscendDeviceContext::IsExecutingSink(const KernelGraphPtr &graph) const { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| return ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK); | |||
| return ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) && IsGraphMode(); | |||
| } | |||
| bool AscendDeviceContext::IsLoopCountSink(const KernelGraphPtr &graph) const { return true; } | |||
| bool AscendDeviceContext::IsLoopCountSink(const KernelGraphPtr &graph) const { return IsGraphMode(); } | |||
| // kernel by kernel mode interface | |||
| void AscendDeviceContext::OptimizeSingleOpGraph(const KernelGraphPtr &graph) const { | |||
| MS_LOG(ERROR) << "!!! Ascend with MindRT not support kernel by kernel mode. !!! "; | |||
| AscendGraphOptimization::GetInstance().OptimizeSingleOpGraph(graph); | |||
| } | |||
| void AscendDeviceContext::PreprocessBeforeRunSingleOpGraph(const KernelGraphPtr &graph) const { | |||
| MS_LOG(ERROR) << "!!! Ascend with MindRT not support kernel by kernel mode. !!! "; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| const auto &nodes = graph->execution_order(); | |||
| // Remove placeholder | |||
| for (const auto &node : nodes) { | |||
| auto op_name = AnfAlgo::GetCNodeName(node); | |||
| static const std::set<std::string> place_holder_nodes = {kDynamicRNNOpName, kDynamicGRUV2OpName}; | |||
| auto iter = place_holder_nodes.find(op_name); | |||
| if (iter != place_holder_nodes.end()) { | |||
| auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, "placeholder_index"); | |||
| // Remove seq_length | |||
| auto input_num = AnfAlgo::GetInputTensorNum(node); | |||
| std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(node)}; | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto item = std::find(none_index.begin(), none_index.end(), i); | |||
| if (item == none_index.end()) { | |||
| auto input_node = AnfAlgo::GetInputNode(node, i); | |||
| new_inputs.emplace_back(input_node); | |||
| } | |||
| } | |||
| node->set_inputs(new_inputs); | |||
| } | |||
| } | |||
| device::ascend::InsertAtomicCleanOps(nodes, &node_atomics_); | |||
| std::vector<CNodePtr> atomic_nodes; | |||
| for (const auto &node : nodes) { | |||
| auto iter = node_atomics_.find(node); | |||
| if (iter != node_atomics_.end()) { | |||
| const auto &atomics = iter->second; | |||
| std::copy(atomics.begin(), atomics.end(), std::back_inserter(atomic_nodes)); | |||
| } | |||
| } | |||
| void AscendDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const { | |||
| MS_LOG(ERROR) << "!!! Ascend with MindRT not support function UpdateDynamicShape. !!! "; | |||
| CreateKernel(atomic_nodes); | |||
| } | |||
| void AscendDeviceContext::UpdateDynamicShape(const CNodePtr &kernel) const {} | |||
| std::shared_ptr<Bucket> AscendDeviceContext::CreateBucket(uint32_t bucket_id, uint32_t bucket_size) const { | |||
| MS_LOG(ERROR) << "!!! Ascend with MindRT not support function CreateBucket. !!! "; | |||
| return DeviceContext::CreateBucket(bucket_id, bucket_size); | |||
| auto bucket = std::make_shared<AscendBucket>(bucket_id, bucket_size); | |||
| MS_EXCEPTION_IF_NULL(bucket); | |||
| bucket->Init({compute_stream_}, {communication_stream_}); | |||
| return bucket; | |||
| } | |||
| bool AscendDeviceContext::SyncRuning() const { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if ((ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) && | |||
| ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !SyncStream()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<AddressPtr> &inputs, | |||
| @@ -582,6 +636,19 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| if (is_dynamic_shape) { | |||
| kernel::AscendKernelMod *ascend_kernel = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod); | |||
| MS_EXCEPTION_IF_NULL(ascend_kernel); | |||
| ascend_kernel->InitDynamicKernel(kernel, compute_stream_); | |||
| auto dynamic_kernel = ascend_kernel->DynamicKernel(); | |||
| MS_EXCEPTION_IF_NULL(dynamic_kernel); | |||
| dynamic_kernel->InferShape(); | |||
| dynamic_kernel->UpdateArgs(); | |||
| dynamic_kernel->Execute(); | |||
| dynamic_kernel->PostExecute(); | |||
| return SyncRuning(); | |||
| } | |||
| std::vector<AddressPtr> real_inputs; | |||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel); | |||
| if (input_num != inputs.size()) { | |||
| @@ -605,21 +672,13 @@ bool AscendDeviceContext::LaunchKernel(const CNodePtr &kernel, const vector<Addr | |||
| return false; | |||
| } | |||
| // Sync running. | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if ((ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) && | |||
| ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE) && !SyncStream()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| return SyncRuning(); | |||
| } | |||
| bool AscendDeviceContext::BindDeviceToCurrentThread() const { | |||
| void AscendDeviceContext::BindDeviceToCurrentThread() const { | |||
| if (initialized_) { | |||
| runtime_instance_->SetContext(); | |||
| } | |||
| return true; | |||
| } | |||
| bool AscendDeviceContext::LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace, | |||
| @@ -124,7 +124,7 @@ class AscendDeviceContext : public DeviceContext { | |||
| bool IsLoopCountSink(const KernelGraphPtr &graph) const override; | |||
| // set rt_context_ to this thread to control device | |||
| bool BindDeviceToCurrentThread() const; | |||
| void BindDeviceToCurrentThread() const; | |||
| // dump all graphs. | |||
| void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs) const override; | |||
| @@ -135,6 +135,8 @@ class AscendDeviceContext : public DeviceContext { | |||
| void AssignInputMemory(const NotNull<KernelGraphPtr> &graph, NotNull<std::set<KernelGraphPtr> *> memo) const; | |||
| void LoadModel(const NotNull<KernelGraphPtr> &root_graph) const; | |||
| void UpdateExecOrder(const KernelGraphPtr &graph) const; | |||
| static bool IsGraphMode(); | |||
| bool SyncRuning() const; | |||
| // Kernel Runtime --- only for task sink | |||
| AscendKernelRuntime *runtime_instance_{nullptr}; | |||
| @@ -157,6 +159,7 @@ class AscendDeviceContext : public DeviceContext { | |||
| bool LaunchAtomicClean(const CNodePtr &node, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) const; | |||
| void *compute_stream_; | |||
| void *communication_stream_; | |||
| }; | |||
| } // namespace ascend | |||
| } // namespace device | |||
| @@ -56,6 +56,20 @@ void AscendGraphOptimization::OptimizeGraph(const KernelGraphPtr &graph) { | |||
| MS_LOG(INFO) << "Status record: end optimize graph. graph id: " << graph->graph_id(); | |||
| } | |||
| void AscendGraphOptimization::OptimizeSingleOpGraph(const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| opt::RunOpAscendBackendIRFusionOptimization(graph); | |||
| SelectKernel(graph); | |||
| opt::RunOpAscendBackendOptimization(graph); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| // Cannot Hide nop node in PyNative mode. | |||
| // If there is more than one node in the graph, | |||
| // and one of the nodes is a nop node, the node will be hidden. | |||
| // The DAG of Actors will be invalid(lack an input edge). | |||
| } | |||
| void AscendGraphOptimization::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| @@ -97,6 +111,7 @@ void AscendGraphOptimization::OptimizeExecutionOrder(const KernelGraphPtr &graph | |||
| DumpIRProto(graph, "before_removeNop_" + std::to_string(graph->graph_id())); | |||
| } | |||
| #endif | |||
| // TODO(sida): do not hide nop op in kernel_by_kernel mode | |||
| if (graph->is_executing_sink()) { | |||
| opt::HideNopNode(graph.get()); | |||
| @@ -43,6 +43,7 @@ class AscendGraphOptimization { | |||
| AscendGraphOptimization &operator=(const AscendGraphOptimization &) = delete; | |||
| void OptimizeGraph(const KernelGraphPtr &graph); | |||
| void OptimizeSingleOpGraph(const KernelGraphPtr &graph); | |||
| void SetOperatorInfo(const std::vector<CNodePtr> &nodes); | |||
| void UnifyMindIR(const KernelGraphPtr &graph); | |||
| @@ -234,28 +234,25 @@ void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, Vec | |||
| } | |||
| } | |||
| void UpdateOutputDeviceAddress(const std::vector<session::KernelWithIndex> &output_nodes, | |||
| const DeviceContext *device_context) { | |||
| for (auto &item_with_index : output_nodes) { | |||
| auto &output_node = item_with_index.first; | |||
| auto output_index = item_with_index.second; | |||
| if (output_node != nullptr) { | |||
| if (!AnfAlgo::OutputAddrExist(output_node, output_index, false)) { | |||
| void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| for (const auto &node : graph->execution_order()) { | |||
| auto output_address_num = AnfAlgo::GetOutputAddressNum(node); | |||
| for (size_t i = 0; i < output_address_num; ++i) { | |||
| if (!AnfAlgo::OutputAddrExist(node, i, false)) { | |||
| continue; | |||
| } | |||
| const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false); | |||
| if ((device_tensor == nullptr) || (device_tensor->GetPtr() == nullptr)) { | |||
| const auto &device_address = AnfAlgo::GetMutableOutputAddr(node, i, false); | |||
| if (device_address == nullptr) { | |||
| continue; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(device_context); | |||
| auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(), | |||
| device_tensor->format(), device_tensor->type_id()); | |||
| MS_EXCEPTION_IF_NULL(new_device_tensor); | |||
| new_device_tensor->set_original_ref_count(device_tensor->original_ref_count()); | |||
| new_device_tensor->ResetRefCount(); | |||
| AnfAlgo::SetOutputAddr(new_device_tensor, output_index, output_node.get()); | |||
| auto new_device_address = device_context->CreateDeviceAddress( | |||
| nullptr, device_address->GetSize(), device_address->format(), device_address->type_id()); | |||
| MS_EXCEPTION_IF_NULL(new_device_address); | |||
| new_device_address->set_original_ref_count(device_address->original_ref_count()); | |||
| new_device_address->ResetRefCount(); | |||
| AnfAlgo::SetOutputAddr(new_device_address, i, node.get()); | |||
| } | |||
| } | |||
| } | |||
| @@ -269,6 +266,51 @@ void UpdateInputDeviceAddress(const KernelGraphPtr &graph) { | |||
| } | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> GetRealValueNodeTensorFromGraph( | |||
| const KernelGraphPtr &graph, size_t input_tensors_size, | |||
| const std::vector<tensor::TensorPtr> &tensors_without_value_node) { | |||
| std::vector<tensor::TensorPtr> new_input_tensors; | |||
| if (graph->execution_order().size() != 1) { | |||
| return new_input_tensors; | |||
| } | |||
| const auto &node = graph->execution_order().back(); | |||
| auto input_num = AnfAlgo::GetInputTensorNum(node); | |||
| // In most scenarios, input_num and input_tensors_size are equal. | |||
| // Except for special procedures, new ValueNode will be added to Graph in GraphOptimize. | |||
| if (input_num == input_tensors_size) { | |||
| return new_input_tensors; | |||
| } | |||
| MS_LOG(INFO) << "CNode input num:" << input_num << " input_tensors size:" << input_tensors_size; | |||
| std::map<size_t, tensor::TensorPtr> value_node_pos; | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto input = AnfAlgo::GetInputNode(node, i); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| if (input->isa<ValueNode>()) { | |||
| auto value_node = input->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| auto value = value_node->value(); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| auto tensor = value->cast<tensor::TensorPtr>(); | |||
| value_node_pos.emplace(i, tensor); | |||
| } | |||
| } | |||
| size_t cur_input_tensor_index = 0; | |||
| for (size_t i = 0; i < input_num; ++i) { | |||
| auto iter = value_node_pos.find(i); | |||
| if (iter == value_node_pos.end()) { | |||
| new_input_tensors.emplace_back(tensors_without_value_node[cur_input_tensor_index]); | |||
| cur_input_tensor_index++; | |||
| } else { | |||
| new_input_tensors.emplace_back(iter->second); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "new input tensor size:" << new_input_tensors.size(); | |||
| return new_input_tensors; | |||
| } | |||
| } // namespace | |||
| VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { | |||
| @@ -1125,6 +1167,9 @@ void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph, | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> new_input_tensors = | |||
| GetRealValueNodeTensorFromGraph(graph, input_tensors.size(), tensors_without_value_node); | |||
| for (auto &tensor : tensors_without_value_node) { | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| if (tensor->NeedWaitDevice()) { | |||
| @@ -1135,7 +1180,8 @@ void MindRTBackend::RunSingleOpGraph(const KernelGraphPtr &graph, | |||
| // Run actor DAG. | |||
| const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(graph_compiler_info->name_); | |||
| MS_EXCEPTION_IF_NULL(actor_set); | |||
| runtime::GraphScheduler::GetInstance().Run(actor_set, {}, {tensors_without_value_node}, input_tensors, | |||
| runtime::GraphScheduler::GetInstance().Run(actor_set, {}, {tensors_without_value_node}, | |||
| new_input_tensors.empty() ? input_tensors : new_input_tensors, | |||
| runtime::GraphExecutionStrategy::kStep); | |||
| // Release the kernel resource. | |||
| @@ -1200,7 +1246,7 @@ void MindRTBackend::LazyExecuteTaskCallback() { | |||
| const auto &context = op_run_task->context(); | |||
| RunSingleOpGraph(context->graph(), context->output_nodes(), context->op_run_info(), | |||
| context->graph_compiler_info(), context->device_context()); | |||
| UpdateOutputDeviceAddress(context->output_nodes(), context->device_context()); | |||
| ClearGraphDeviceAddress(context->graph(), context->device_context()); | |||
| UpdateInputDeviceAddress(context->graph()); | |||
| op_lazy_builder.PopOpRunTask(); | |||
| @@ -1258,7 +1304,7 @@ void MindRTBackend::RunOpInternal(bool single_op_cache_hit, GraphCompilerInfo *g | |||
| } | |||
| RunSingleOpGraph(graph, output_nodes, *op_run_info, graph_compiler_info, device_context); | |||
| UpdateOutput(output_nodes, outputs); | |||
| UpdateOutputDeviceAddress(output_nodes, device_context); | |||
| ClearGraphDeviceAddress(graph, device_context); | |||
| UpdateInputDeviceAddress(graph); | |||
| if (op_run_info->is_dynamic_shape) { | |||
| UpdateOutputAbstract(graph, op_run_info); | |||