Browse Source

try to get address pointer from cache

tags/v1.1.0
tronzhang 5 years ago
parent
commit
1cf2482ba5
5 changed files with 131 additions and 15 deletions
  1. +3
    -0
      mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc
  2. +23
    -2
      mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc
  3. +8
    -0
      mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h
  4. +82
    -13
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
  5. +15
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h

+ 3
- 0
mindspore/ccsrc/backend/optimizer/graph_kernel/arithmetic_simplify.cc View File

@@ -15,7 +15,10 @@
*/ */
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h" #include "backend/optimizer/graph_kernel/arithmetic_simplify.h"


#include <algorithm>
#include <list> #include <list>
#include <utility>
#include <vector>
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h" #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"


+ 23
- 2
mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc View File

@@ -268,10 +268,10 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
session::KernelWithIndex kernel_input; session::KernelWithIndex kernel_input;
if (is_all_nop_node_) { if (is_all_nop_node_) {
// The graph does not remove the nop node. // The graph does not remove the nop node.
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
kernel_input = VisitKernelWithReturnType(input_node, 0, false);
} else { } else {
// The graph removes the nop node. // The graph removes the nop node.
kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
kernel_input = VisitKernelWithReturnType(input_node, 0, true);
} }
if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) { if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple"; MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
@@ -459,6 +459,10 @@ void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
#ifdef MEM_REUSE_DEBUG #ifdef MEM_REUSE_DEBUG
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph); MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
#endif #endif

auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
enable_visit_kernel_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
} }


uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const { uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
@@ -490,5 +494,22 @@ uint8_t *MemReuseUtil::GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index)
} }
return ptr; return ptr;
} }

session::KernelWithIndex MemReuseUtil::VisitKernelWithReturnType(const AnfNodePtr &node, size_t i,
bool visit_nop_node) {
if (!enable_visit_kernel_cache_ || i != 0) {
return AnfAlgo::VisitKernelWithReturnType(node, i, visit_nop_node);
}

auto &cache =
visit_nop_node ? visit_kernel_with_return_type_in0pos_cache_ : visit_kernel_with_return_type_in0pos_skip_nop_cache_;
std::unordered_map<AnfNodePtr, session::KernelWithIndex>::iterator tag_iter;
if (auto iter = cache.find(node); iter == cache.end()) {
tag_iter = cache.insert({node, AnfAlgo::VisitKernelWithReturnType(node, i, visit_nop_node)}).first;
} else {
tag_iter = iter;
}
return tag_iter->second;
}
} // namespace memreuse } // namespace memreuse
} // namespace mindspore } // namespace mindspore

+ 8
- 0
mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h View File

@@ -18,11 +18,13 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_REUSE_H_
#include <map> #include <map>
#include <memory> #include <memory>
#include <unordered_map>
#include <vector> #include <vector>
#include "backend/optimizer/mem_reuse/kernel_refcount.h" #include "backend/optimizer/mem_reuse/kernel_refcount.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h" #include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "utils/ms_context.h"
using mindspore::kernel::tbe::TbeUtils; using mindspore::kernel::tbe::TbeUtils;
namespace mindspore { namespace mindspore {
namespace memreuse { namespace memreuse {
@@ -85,6 +87,7 @@ class MemReuseUtil {
uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const; uint8_t *GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const;
uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const; uint8_t *GetNodeWorkSpacePtr(const AnfNodePtr &node, size_t index) const;
bool is_all_nop_node() const { return is_all_nop_node_; } bool is_all_nop_node() const { return is_all_nop_node_; }
session::KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &node, size_t i, bool visit_nop_node);


private: private:
int util_index_; int util_index_;
@@ -101,6 +104,11 @@ class MemReuseUtil {
uint8_t *mem_base_{nullptr}; uint8_t *mem_base_{nullptr};
// kernel_map_: key is the AnfNodePtr addr, value is the KernelDef // kernel_map_: key is the AnfNodePtr addr, value is the KernelDef
std::map<KernelKey, KernelDefPtr> kernel_map_; std::map<KernelKey, KernelDefPtr> kernel_map_;

bool enable_visit_kernel_cache_{false};

std::unordered_map<AnfNodePtr, session::KernelWithIndex> visit_kernel_with_return_type_in0pos_cache_;
std::unordered_map<AnfNodePtr, session::KernelWithIndex> visit_kernel_with_return_type_in0pos_skip_nop_cache_;
}; };
using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>; using MemReuseUtilPtr = std::shared_ptr<MemReuseUtil>;
} // namespace memreuse } // namespace memreuse


+ 82
- 13
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc View File

@@ -51,6 +51,10 @@ static const size_t PARAMETER_OUTPUT_INDEX = 0;
bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); } bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); }


bool GPUKernelRuntime::Init() { bool GPUKernelRuntime::Init() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
enable_relation_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);

if (device_init_ == true) { if (device_init_ == true) {
GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory(); GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
return true; return true;
@@ -282,7 +286,7 @@ void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph)


auto primitive = AnfAlgo::GetCNodePrimitive(item[0]); auto primitive = AnfAlgo::GetCNodePrimitive(item[0]);
auto output_index = GetValue<uint32_t>(primitive->GetAttr("inplace_output_index")); auto output_index = GetValue<uint32_t>(primitive->GetAttr("inplace_output_index"));
auto device_address = AnfAlgo::GetMutableOutputAddr(item[0], output_index, false);
auto device_address = GetMutableOutputAddr(item[0], output_index, false);
if (device_address->GetPtr() != nullptr) { if (device_address->GetPtr() != nullptr) {
continue; continue;
} }
@@ -670,7 +674,7 @@ bool GPUKernelRuntime::AddMemorySwapTask(const AnfNodePtr &kernel, bool mock, bo
MS_EXCEPTION_IF_NULL(need_swap_kernel); MS_EXCEPTION_IF_NULL(need_swap_kernel);
const HostAddress &host_address = const HostAddress &host_address =
mem_swap_manager_->QueryKernelHostAddr(need_swap_kernel, mem_swap_info.output_idx_); mem_swap_manager_->QueryKernelHostAddr(need_swap_kernel, mem_swap_info.output_idx_);
auto device_address = AnfAlgo::GetMutableOutputAddr(need_swap_kernel, mem_swap_info.output_idx_, false);
auto device_address = GetMutableOutputAddr(need_swap_kernel, mem_swap_info.output_idx_, false);


if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) { if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
if (mem_swap_manager_->QueryKernelHostAddrIsDirty(need_swap_kernel, mem_swap_info.output_idx_)) { if (mem_swap_manager_->QueryKernelHostAddrIsDirty(need_swap_kernel, mem_swap_info.output_idx_)) {
@@ -821,10 +825,10 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
DeviceAddressPtr device_address; DeviceAddressPtr device_address;
if (mem_reuse_util_->is_all_nop_node()) { if (mem_reuse_util_->is_all_nop_node()) {
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
device_address = GetPrevNodeMutableOutputAddr(kernel, i, false);
} else { } else {
// Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node. // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
device_address = GetPrevNodeMutableOutputAddr(kernel, i, true);
} }


// Get in-place output_address // Get in-place output_address
@@ -833,7 +837,7 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
auto input_index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index")); auto input_index = GetValue<uint32_t>(primitive->GetAttr("aggregate_input_index"));
if (i == input_index) { if (i == input_index) {
auto skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(kernel), input_index); auto skip_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(kernel), input_index);
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(skip_node, 0, false);
device_address = GetPrevNodeMutableOutputAddr(skip_node, 0, false);
MS_LOG(INFO) << "[inplace optimizer] aggregate: " << kernel->DebugString() MS_LOG(INFO) << "[inplace optimizer] aggregate: " << kernel->DebugString()
<< ", skip: " << skip_node->DebugString() << ", address: " << device_address->GetMutablePtr(); << ", skip: " << skip_node->DebugString() << ", address: " << device_address->GetMutablePtr();
} }
@@ -859,7 +863,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern
UpdateHostSwapOutQueue(mock); UpdateHostSwapOutQueue(mock);
auto output_sizes = kernel_mod.GetOutputSizeList(); auto output_sizes = kernel_mod.GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) { for (size_t i = 0; i < output_sizes.size(); ++i) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
auto device_address = GetMutableOutputAddr(kernel, i, false);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i], mock)) { if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i], mock)) {
return false; return false;
@@ -929,10 +933,10 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN
DeviceAddressPtr device_address; DeviceAddressPtr device_address;
if (mem_reuse_util_->is_all_nop_node()) { if (mem_reuse_util_->is_all_nop_node()) {
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
device_address = GetPrevNodeMutableOutputAddr(kernel, i, false);
} else { } else {
// Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node. // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
device_address = GetPrevNodeMutableOutputAddr(kernel, i, true);
} }
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
if (device_address->ptr_ == nullptr) { if (device_address->ptr_ == nullptr) {
@@ -958,7 +962,7 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf
MS_EXCEPTION_IF_NULL(kernel_mod); MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList(); auto output_sizes = kernel_mod->GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) { for (size_t i = 0; i < output_sizes.size(); ++i) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
auto device_address = GetMutableOutputAddr(kernel, i, false);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);
if (device_address->ptr_ == nullptr) { if (device_address->ptr_ == nullptr) {
is_need_alloc_memory = true; is_need_alloc_memory = true;
@@ -1013,7 +1017,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
} }
} }


auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i);
auto kernel_with_index = GetPrevNodeOutput(kernel, i);
if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) { if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
continue; continue;
} }
@@ -1030,10 +1034,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
DeviceAddressPtr device_address; DeviceAddressPtr device_address;
if (mem_reuse_util_->is_all_nop_node()) { if (mem_reuse_util_->is_all_nop_node()) {
// Graph may be all nop nodes and not remove nop node, so this can not skip nop node. // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
device_address = GetPrevNodeMutableOutputAddr(kernel, i, false);
} else { } else {
// Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node. // Graph may be "nop node + depend + node", the input of node is the depend, so this case need skip nop node.
device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
device_address = GetPrevNodeMutableOutputAddr(kernel, i, true);
} }
mem_manager_->FreeMemFromMemPool(device_address); mem_manager_->FreeMemFromMemPool(device_address);
device_address->set_status(DeviceAddressStatus::kInDevice); device_address->set_status(DeviceAddressStatus::kInDevice);
@@ -1046,7 +1050,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
continue; continue;
} }
if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
auto device_address = GetMutableOutputAddr(kernel, i, false);
mem_manager_->FreeMemFromMemPool(device_address); mem_manager_->FreeMemFromMemPool(device_address);
device_address->set_status(DeviceAddressStatus::kInDevice); device_address->set_status(DeviceAddressStatus::kInDevice);
} }
@@ -1062,6 +1066,71 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel)
} }
} }
} }

DeviceAddressPtr GPUKernelRuntime::GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
if (!enable_relation_cache_) {
return AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
}

auto &addr_cache = visit_nop_node ? prev_node_mut_output_addr_cache_ : prev_node_mut_output_addr_skip_nop_node_cache_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>>::iterator addr_iter;
if (auto iter = addr_cache.find(node); iter == addr_cache.end()) {
addr_iter = addr_cache.insert({node, {AnfAlgo::GetInputTensorNum(node), nullptr}}).first;
} else {
addr_iter = iter;
}

if (addr_iter->second[i] == nullptr) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, visit_nop_node);
addr_iter->second[i] = device_address;
}

return addr_iter->second[i];
}

DeviceAddressPtr GPUKernelRuntime::GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node) {
if (!enable_relation_cache_) {
return AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
}

auto &addr_cache = visit_nop_node ? mut_output_addr_cache_ : mut_output_addr_skip_nop_node_cache_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>>::iterator addr_iter;
if (auto iter = addr_cache.find(node); iter == addr_cache.end()) {
auto kernel_mod = AnfAlgo::GetKernelMod(node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();
addr_iter = addr_cache.insert({node, {output_sizes.size(), nullptr}}).first;
} else {
addr_iter = iter;
}

if (addr_iter->second[i] == nullptr) {
auto device_address = AnfAlgo::GetMutableOutputAddr(node, i, visit_nop_node);
addr_iter->second[i] = device_address;
}

return addr_iter->second[i];
}

session::KernelWithIndex GPUKernelRuntime::GetPrevNodeOutput(const AnfNodePtr &node, size_t i) {
if (!enable_relation_cache_) {
return AnfAlgo::GetPrevNodeOutput(node, i);
}

std::unordered_map<AnfNodePtr, std::vector<session::KernelWithIndex>>::iterator addr_iter;
if (auto iter = prev_node_output_cache_.find(node); iter == prev_node_output_cache_.end()) {
addr_iter = prev_node_output_cache_.insert({node, {AnfAlgo::GetInputTensorNum(node), {nullptr, 0}}}).first;
} else {
addr_iter = iter;
}

if (addr_iter->second[i].first == nullptr) {
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, i);
addr_iter->second[i] = kernel_with_index;
}

return addr_iter->second[i];
}
} // namespace gpu } // namespace gpu
} // namespace device } // namespace device
} // namespace mindspore } // namespace mindspore

+ 15
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h View File

@@ -18,12 +18,14 @@
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_ #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_RUNTIME_H_


#include <string> #include <string>
#include <map>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <set> #include <set>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_runtime.h" #include "runtime/device/kernel_runtime.h"
#include "runtime/device/kernel_runtime_manager.h" #include "runtime/device/kernel_runtime_manager.h"
#include "backend/optimizer/mem_reuse/mem_swap_manager.h" #include "backend/optimizer/mem_reuse/mem_swap_manager.h"
@@ -96,6 +98,11 @@ class GPUKernelRuntime : public KernelRuntime {
void UpdateHostSwapOutQueue(bool mock); void UpdateHostSwapOutQueue(bool mock);
void ClearSwapInfo(bool mock); void ClearSwapInfo(bool mock);
void AllocInplaceNodeMemory(const session::KernelGraph *graph); void AllocInplaceNodeMemory(const session::KernelGraph *graph);

DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);
DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t i, bool visit_nop_node);
session::KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &node, size_t i);

std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_; std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
std::unordered_map<uint32_t, bool> is_first_step_map_; std::unordered_map<uint32_t, bool> is_first_step_map_;
@@ -105,6 +112,14 @@ class GPUKernelRuntime : public KernelRuntime {


MemReuseUtilPtr mem_reuse_util_{nullptr}; MemReuseUtilPtr mem_reuse_util_{nullptr};
MemSwapManagerPtr mem_swap_manager_{nullptr}; MemSwapManagerPtr mem_swap_manager_{nullptr};

bool enable_relation_cache_{false};

std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_cache_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> prev_node_mut_output_addr_skip_nop_node_cache_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_cache_;
std::unordered_map<AnfNodePtr, std::vector<DeviceAddressPtr>> mut_output_addr_skip_nop_node_cache_;
std::unordered_map<AnfNodePtr, std::vector<session::KernelWithIndex>> prev_node_output_cache_;
}; };
MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);
} // namespace gpu } // namespace gpu


Loading…
Cancel
Save