Browse Source

GPU iterator: use weak references for iterator cleanup (weak ref optimization)

tags/v0.5.0-beta
panfengfeng 5 years ago
parent
commit
636d419af3
4 changed files with 15 additions and 47 deletions
  1. +9
    -7
      mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc
  2. +2
    -14
      mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc
  3. +0
    -1
      mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h
  4. +4
    -25
      mindspore/dataset/engine/iterators.py

+ 9
- 7
mindspore/ccsrc/dataset/engine/datasetops/device_queue_op.cc View File

@@ -26,10 +26,6 @@
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"

#ifdef ENABLE_TDTQUE
#include "tdt/tsd_client.h"
#endif

namespace mindspore {
namespace dataset {
DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
@@ -167,9 +163,15 @@ Status DeviceQueueOp::SendDataToGPU() {
is_break_loop = true;
}
}
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
if (!TaskManager::FindMe()->Interrupted())
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
else
is_break_loop = true;
}
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
if (!TaskManager::FindMe()->Interrupted())
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
else
is_break_loop = true;
}

MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
@@ -191,7 +193,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
items.push_back(data_item);
}

while (!GpuBufferMgr::GetInstance().IsClosed()) {
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
RETURN_IF_NOT_OK(MallocForGPUData(&items, curr_row));
auto ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);
if (ret) {


+ 2
- 14
mindspore/ccsrc/device/gpu/gpu_buffer_mgr.cc View File

@@ -172,9 +172,7 @@ bool GpuBufferMgr::CloseNotify() {
{
std::lock_guard<std::mutex> lk(close_mutex_);
// set closed_ to be true, all the dataset retry can be jumped out of the while
closed_ = true; // set closed_ to be true, all the dataset retry can be jumped out of the while
// notify all the waiting dataset threads
close_confirm_cond_.notify_all(); // notify all the waiting dataset threads
closed_ = true;
}

// wait for the dataset threads' ack
@@ -188,16 +186,6 @@ bool GpuBufferMgr::CloseNotify() {
return result;
}

void GpuBufferMgr::CloseConfirm() {
// lock scope
{
std::unique_lock<std::mutex> lk(close_mutex_);
// dataset threads wait for the closed_ flag from false to true
close_confirm_cond_.wait(
lk, [this] { return closed_; }); // dataset threads wait for the closed_ flag from false to true
}

sema.Signal();
}
void GpuBufferMgr::CloseConfirm() { sema.Signal(); }
} // namespace device
} // namespace mindspore

+ 0
- 1
mindspore/ccsrc/device/gpu/gpu_buffer_mgr.h View File

@@ -119,7 +119,6 @@ class GpuBufferMgr {
bool closed_;
std::mutex mutex_;
std::mutex close_mutex_;
std::condition_variable close_confirm_cond_;
// how many queues opened by dataset
int open_by_dataset_;
Semaphore sema;


+ 4
- 25
mindspore/dataset/engine/iterators.py View File

@@ -17,7 +17,6 @@
from abc import abstractmethod
import copy
import weakref
from importlib import import_module

from mindspore._c_dataengine import DEPipeline
from mindspore._c_dataengine import OpName
@@ -25,10 +24,6 @@ from mindspore._c_dataengine import OpName
from mindspore import log as logger
from . import datasets as de

try:
context = import_module("mindspore.context")
except ModuleNotFoundError:
context = None

ITERATORS_LIST = list()

@@ -36,18 +31,9 @@ ITERATORS_LIST = list()
def _cleanup():
"""Release all the Iterator."""
for itr_ref in ITERATORS_LIST:
if context:
device_type = context.get_context("device_target")
if device_type == "GPU":
itr_ref.release()
else:
itr = itr_ref()
if itr is not None:
itr.release()
else:
itr = itr_ref()
if itr is not None:
itr.release()
itr = itr_ref()
if itr is not None:
itr.release()


def alter_tree(node):
@@ -101,14 +87,7 @@ class Iterator:
"""

def __init__(self, dataset):
if context:
device_type = context.get_context("device_target")
if device_type == "GPU":
ITERATORS_LIST.append(self)
else:
ITERATORS_LIST.append(weakref.ref(self))
else:
ITERATORS_LIST.append(weakref.ref(self))
ITERATORS_LIST.append(weakref.ref(self))
# create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset)
self.dataset = alter_tree(self.dataset)


Loading…
Cancel
Save