| @@ -484,6 +484,13 @@ void GPUSession::UpdateOutputTensors(const VectorRef *outputs, | |||||
| if (node->isa<CNode>()) { | if (node->isa<CNode>()) { | ||||
| auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize()); | auto new_address = std::make_shared<device::gpu::GPUDeviceAddress>(nullptr, address->GetSize()); | ||||
| AnfAlgo::SetOutputAddr(new_address, output_index, node.get()); | AnfAlgo::SetOutputAddr(new_address, output_index, node.get()); | ||||
| if (context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { | |||||
| auto runtime_instance = | |||||
| device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); | |||||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||||
| auto gpu_runtime_instance = dynamic_cast<device::gpu::GPUKernelRuntime *>(runtime_instance); | |||||
| gpu_runtime_instance->SetAddrInvalid(address); | |||||
| } | |||||
| } | } | ||||
| if (AnfAlgo::IsDynamicShape(node)) { | if (AnfAlgo::IsDynamicShape(node)) { | ||||
| @@ -1220,12 +1220,21 @@ DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr | |||||
| addr_iter = iter; | addr_iter = iter; | ||||
| } | } | ||||
| if (addr_iter->second[i] == nullptr) { | |||||
| auto &now_addr = addr_iter->second[i]; | |||||
| if (now_addr == nullptr) { | |||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node); | auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node); | ||||
| addr_iter->second[i] = device_address; | |||||
| now_addr = device_address; | |||||
| addr_state_[now_addr] = true; | |||||
| } else { | |||||
| auto addr_state_iter = addr_state_.find(now_addr); | |||||
| if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) { | |||||
| auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node); | |||||
| now_addr = device_address; | |||||
| addr_state_[now_addr] = true; | |||||
| } | |||||
| } | } | ||||
| return addr_iter->second[i]; | |||||
| return now_addr; | |||||
| } | } | ||||
| DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) { | DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) { | ||||
| @@ -1244,12 +1253,21 @@ DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, | |||||
| addr_iter = iter; | addr_iter = iter; | ||||
| } | } | ||||
| if (addr_iter->second[i] == nullptr) { | |||||
| auto &now_addr = addr_iter->second[i]; | |||||
| if (now_addr == nullptr) { | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node); | auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node); | ||||
| addr_iter->second[i] = device_address; | |||||
| now_addr = device_address; | |||||
| addr_state_[now_addr] = true; | |||||
| } else { | |||||
| auto addr_state_iter = addr_state_.find(now_addr); | |||||
| if (addr_state_iter != addr_state_.end() && addr_state_iter->second == false) { | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node); | |||||
| now_addr = device_address; | |||||
| addr_state_[now_addr] = true; | |||||
| } | |||||
| } | } | ||||
| return addr_iter->second[i]; | |||||
| return now_addr; | |||||
| } | } | ||||
| session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) { | session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) { | ||||
| @@ -50,6 +50,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; }; | DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; }; | ||||
| void *compute_stream() const override { return stream_; } | void *compute_stream() const override { return stream_; } | ||||
| void *communication_stream() const override { return communication_stream_; } | void *communication_stream() const override { return communication_stream_; } | ||||
| void SetAddrInvalid(const DeviceAddressPtr &addr) { addr_state_[addr] = false; } | |||||
| protected: | protected: | ||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | ||||
| @@ -121,6 +122,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| bool enable_relation_cache_{false}; | bool enable_relation_cache_{false}; | ||||
| std::unordered_map<DeviceAddressPtr, bool> addr_state_; | |||||
| std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_; | std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_; | ||||
| std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_; | std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_; | ||||
| std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_; | std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_; | ||||