diff --git a/ge/graph/manager/graph_caching_allocator.cc b/ge/graph/manager/graph_caching_allocator.cc index bfef4001..dd46e670 100644 --- a/ge/graph/manager/graph_caching_allocator.cc +++ b/ge/graph/manager/graph_caching_allocator.cc @@ -28,10 +28,9 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, kBinSizeUnit8 * kMByteSize, kBinSizeUnit32 * kMByteSize, kBinSizeUnit128 * kMByteSize, - kGByteSize, - kBinSizeUnit4 * kGByteSize, - kBinSizeUnit16 * kGByteSize, - kBinSizeUnit26 * kGByteSize}; + kBinSizeUnit256 * kMByteSize, + kBinSizeUnit512 * kMByteSize, + kGByteSize}; static bool BlockComparator(const Block *left, const Block *right) { if (left->size != right->size) { @@ -63,7 +62,10 @@ size_t GetBinIndex(size_t size) { size_t GetAllocationSize(size_t size) { size_t index = GetBinIndex(size); - return bin_ranges[index]; + if (bin_ranges[index] >= size) { + return bin_ranges[index]; + } + return kGByteSize * ((size + kGByteSize - 1) / kGByteSize); } /// diff --git a/ge/graph/manager/graph_caching_allocator.h b/ge/graph/manager/graph_caching_allocator.h index e024d5cd..42d0952d 100644 --- a/ge/graph/manager/graph_caching_allocator.h +++ b/ge/graph/manager/graph_caching_allocator.h @@ -36,17 +36,17 @@ namespace ge { constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes constexpr size_t kBinSizeUnit4 = 4; constexpr size_t kBinSizeUnit8 = 8; -constexpr size_t kBinSizeUnit16 = 16; -constexpr size_t kBinSizeUnit26 = 26; constexpr size_t kBinSizeUnit32 = 32; constexpr size_t kBinSizeUnit128 = 128; +constexpr size_t kBinSizeUnit256 = 256; +constexpr size_t kBinSizeUnit512 = 512; -constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold +constexpr double kSplitThreshold = 0.5; // split when malloc size <= small block size * kSplitThreshold constexpr size_t kKByteSize = 1024; constexpr size_t kMByteSize = 1048576; // 1024 * 1024 constexpr size_t kGByteSize = 1073741824; // 
1024 * 1024 * 1024 -static const uint32_t kNumBins = 8; +static const uint32_t kNumBins = 7; class MemoryAllocator; diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc index 7f4fa78c..bd53bec4 100755 --- a/ge/hybrid/executor/worker/execution_engine.cc +++ b/ge/hybrid/executor/worker/execution_engine.cc @@ -322,7 +322,8 @@ Status NodeDoneCallback::OnNodeDone() { GE_CHK_STATUS_RET(ProfilingReport(), "Report node[%s] to profiling failed.", node_item.NodeName().c_str()); } - + // release workspace + context_->ReleaseWorkspace(); // release inputs for (int i = 0; i < context_->NumInputs(); ++i) { context_->ReleaseInput(i); diff --git a/ge/hybrid/node_executor/task_context.cc b/ge/hybrid/node_executor/task_context.cc index e3cf5ae1..085970e0 100644 --- a/ge/hybrid/node_executor/task_context.cc +++ b/ge/hybrid/node_executor/task_context.cc @@ -36,10 +36,6 @@ TaskContext::TaskContext(GraphExecutionContext *execution_context, TaskContext::~TaskContext() { GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); - for (auto ws_addr : workspaces_) { - execution_context_->allocator->Deallocate(ws_addr); - } - // release output for (int i = 0; i < NumOutputs(); ++i) { auto output_tensor = MutableOutput(i); @@ -49,6 +45,13 @@ TaskContext::~TaskContext() { } } +void TaskContext::ReleaseWorkspace() { + GELOGD("[%s] Start ReleaseWorkspace.", node_item_->NodeName().c_str()); + for (auto ws_addr : workspaces_) { + execution_context_->allocator->Deallocate(ws_addr); + } +} + std::unique_ptr TaskContext::Create(NodeState *node_state, GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) { diff --git a/ge/hybrid/node_executor/task_context.h b/ge/hybrid/node_executor/task_context.h index dc4ff058..6bdf8014 100644 --- a/ge/hybrid/node_executor/task_context.h +++ b/ge/hybrid/node_executor/task_context.h @@ -55,6 +55,7 @@ class TaskContext { GeTensorDescPtr MutableOutputDesc(int index) const; void 
ReleaseInputsAndOutputs(); bool NeedCallback(); + void ReleaseWorkspace(); void ReleaseInput(int index); const TensorValue *GetInput(int index) const; const TensorValue *GetOutput(int index) const;