diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 904297749d..fc7f8738f8 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -484,6 +484,13 @@ void GPUSession::UpdateOutputTensors(const VectorRef *outputs, if (node->isa()) { auto new_address = std::make_shared(nullptr, address->GetSize()); AnfAlgo::SetOutputAddr(new_address, output_index, node.get()); + if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { + auto runtime_instance = + device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + auto gpu_runtime_instance = dynamic_cast(runtime_instance); + gpu_runtime_instance->SetAddrInvalid(address); + } } if (AnfAlgo::IsDynamicShape(node)) { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 9c2116df31..2e93144b54 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -1220,12 +1220,21 @@ DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr addr_iter = iter; } - if (addr_iter->second[i] == nullptr) { + auto &now_addr = addr_iter->second[i]; + if (now_addr == nullptr) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node); - addr_iter->second[i] = device_address; + now_addr = device_address; + addr_state_[now_addr] = true; + } else { + auto addr_state_iter = addr_state_.find(now_addr); + if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node); + now_addr = device_address; + addr_state_[now_addr] = true; + } } - return addr_iter->second[i]; + return now_addr; } DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) { @@ -1244,12 +1253,21 @@ DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, addr_iter = iter; } - if (addr_iter->second[i] == nullptr) { + auto &now_addr = addr_iter->second[i]; + if (now_addr == nullptr) { auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node); - addr_iter->second[i] = device_address; + now_addr = device_address; + addr_state_[now_addr] = true; + } else { + auto addr_state_iter = addr_state_.find(now_addr); + if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) { + auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node); + now_addr = device_address; + addr_state_[now_addr] = true; + } } - return addr_iter->second[i]; + return now_addr; } session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index a7f175d254..e8c77b073e 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -50,6 +50,7 @@ class GPUKernelRuntime : public KernelRuntime { DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; }; void *compute_stream() const override { return stream_; } void *communication_stream() const override { return communication_stream_; } + void SetAddrInvalid(const DeviceAddressPtr &addr) { addr_state_[addr] = false; } protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, @@ -121,6 +122,7 @@ class GPUKernelRuntime : public KernelRuntime { bool enable_relation_cache_{false}; + std::unordered_map addr_state_; std::unordered_map> prev_node_mut_output_addr_cache_; std::unordered_map> prev_node_mut_output_addr_skip_nop_node_cache_; std::unordered_map> mut_output_addr_cache_;