
feat(mge/data): dpflow dataset, stream sampler and loader

Megvii Engine Team, 5 years ago
commit f04e0d777e (tags/v1.1.0)
GitOrigin-RevId: cbb4510a13
3 changed files with 236 additions and 16 deletions
  1. imperative/python/megengine/data/__init__.py (+1, -0)
  2. imperative/python/megengine/data/dataloader.py (+198, -11)
  3. imperative/python/megengine/data/sampler.py (+37, -5)
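
Taken together, the three files add a streaming data path: DataLoader now dispatches on the dataset type, StreamSampler carries only a batch size, and the new loader iterators pull samples straight off the dataset's own iterator. A minimal usage sketch follows. It is an assumption-laden illustration, not part of the commit: dataset.py is untouched here, so the exact StreamDataset contract (an iterable dataset yielding tuples of equal-length, index-aligned sequences, plus the post_process hook the iterators call) is inferred from the loader code below, and NoisyStream is a hypothetical class.

    # Hypothetical sketch of the new stream pipeline; the StreamDataset
    # contract is inferred from this diff, not confirmed by it.
    import numpy as np

    from megengine.data import DataLoader, StreamSampler
    from megengine.data.dataset import StreamDataset

    class NoisyStream(StreamDataset):
        def __iter__(self):
            # Each item is a tuple of index-aligned sequences; the loader
            # re-pairs them per sample via tuple(e[idx] for e in item).
            while True:
                images = np.random.rand(4, 3, 32, 32).astype("float32")
                labels = np.random.randint(0, 10, size=(4,))
                yield (images, labels)

    # With sampler=None the loader would default to StreamSampler(batch_size=1);
    # pass one explicitly to choose the batch size.
    dataloader = DataLoader(NoisyStream(), sampler=StreamSampler(batch_size=16))
    for step, minibatch in enumerate(dataloader):  # the stream never ends
        if step == 10:
            break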

imperative/python/megengine/data/__init__.py (+1, -0)

@@ -14,4 +14,5 @@ from .sampler import (
    ReplacementSampler,
    Sampler,
    SequentialSampler,
+    StreamSampler,
)

imperative/python/megengine/data/dataloader.py (+198, -11)

@@ -19,8 +19,8 @@ import numpy as np
from ..logger import get_logger
from ..random.rng import _random_seed_generator
from .collator import Collator
-from .dataset import Dataset
-from .sampler import Sampler, SequentialSampler
+from .dataset import Dataset, MapDataset, StreamDataset
+from .sampler import Sampler, SequentialSampler, StreamSampler
from .transform import PseudoTransform, Transform

logger = get_logger(__name__)
@@ -82,13 +82,21 @@ class DataLoader:
            raise ValueError("divide should not be set to True when num_workers <= 1")

        self.dataset = dataset
        self.num_workers = num_workers
        self.timeout = timeout

        self.divide = divide

        if sampler is None:
-            self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            if isinstance(dataset, MapDataset):
+                self.sampler = SequentialSampler(dataset, batch_size=1, drop_last=False)
+            elif isinstance(dataset, StreamDataset):
+                self.sampler = StreamSampler(batch_size=1)
+            else:
+                raise TypeError(
+                    "can not recognize this kind of dataset: %s" % type(dataset)
+                )
        else:
            self.sampler = sampler


@@ -120,16 +128,26 @@ class DataLoader:
                "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
            )
            self.num_workers = 0
-        if self.num_workers == 0:
-            return _SerialDataLoaderIter(self)
+        if isinstance(self.dataset, StreamDataset):
+            if not self.num_workers:
+                return _SerialStreamDataLoaderIter(self)
+            else:
+                return _ParallelStreamDataLoaderIter(self)
+        elif isinstance(self.dataset, MapDataset):
+            if not self.num_workers:
+                return _SerialMapDataLoaderIter(self)
+            else:
+                return _ParallelMapDataLoaderIter(self)
        else:
-            return _ParallelDataLoaderIter(self)
+            raise TypeError(
+                "can not recognize this kind of dataset: %s" % type(self.dataset)
+            )

    def __len__(self):
        return len(self.sampler)




-class _BaseDataLoaderIter:
+class _BaseMapDataLoaderIter:
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.sampler = loader.sampler
@@ -158,9 +176,9 @@ class _BaseDataLoaderIter:
        return minibatch


-class _SerialDataLoaderIter(_BaseDataLoaderIter):
+class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
    def __init__(self, loader):
-        super(_SerialDataLoaderIter, self).__init__(loader)
+        super(_SerialMapDataLoaderIter, self).__init__(loader)
        self.indices_iter = iter(self.sampler)

    def _get_next_batch(self):
@@ -170,11 +188,11 @@ class _SerialDataLoaderIter(_BaseDataLoaderIter):
        return self.collator.apply(trans_items)


-class _ParallelDataLoaderIter(_BaseDataLoaderIter):
+class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
    __initialized = False

    def __init__(self, loader):
-        super(_ParallelDataLoaderIter, self).__init__(loader)
+        super(_ParallelMapDataLoaderIter, self).__init__(loader)

        self.task_queues = [
            multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -326,6 +344,175 @@ class _ParallelDataLoaderIter(_BaseDataLoaderIter):
        self._shutdown()




+class _BaseStreamDataLoaderIter:
+    def __init__(self, loader):
+        self.dataset = loader.dataset
+        self.sampler = loader.sampler
+        self.transform = loader.transform
+        self.collator = loader.collator
+        self.num_workers = loader.num_workers
+        self.timeout = loader.timeout
+        self.post_process = self.dataset.post_process
+
+    def _get_next_batch(self):
+        raise NotImplementedError
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return self.post_process(self._get_next_batch())
+
+
+class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    def __init__(self, loader):
+        super().__init__(loader)
+        self.dataset_iter = iter(self.dataset)
+
+    def _get_next_batch(self):
+        ret = []
+        start_time = time.time()
+        while len(ret) != self.sampler.batch_size:
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+            item = next(self.dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                ret.append(trans_item)
+                if len(ret) == self.sampler.batch_size:
+                    break
+        return self.collator.apply(ret)
+
+
+class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
+    __initialized = False
+
+    def __init__(self, loader):
+        super().__init__(loader)
+
+        self.shutdown_flag = multiprocessing.Value("i", 0)
+
+        # shared-memory queue implemented by pyarrow plasma store
+        from ._queue import PlasmaShmQueue
+
+        self.batch_queue = PlasmaShmQueue(maxsize=2)
+        self.workers = []
+        self.worker_queues = [
+            multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
+        ]
+        for worker_id in range(self.num_workers):
+            worker = multiprocessing.Process(
+                target=self._gen_data, args=(worker_id,), daemon=True
+            )
+            worker.start()
+            self.workers.append(worker)
+        self.collator_worker = multiprocessing.Process(
+            target=self._gen_batch, daemon=True
+        )
+        self.collator_worker.start()
+
+        self.__initialized = True
+
+    def _gen_data(self, worker_id):
+        dataset_iter = iter(self.dataset)
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            item = next(dataset_iter)
+            for idx in range(len(item[0])):
+                trans_item = self.transform.apply(tuple(e[idx] for e in item))
+                while True:
+                    try:
+                        self.worker_queues[worker_id].put(trans_item)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch part queue is full")
+
+    def _gen_batch(self):
+        cnt = -1
+        trans_items = []
+        while True:
+            if self.shutdown_flag.value == 1:
+                break
+            cnt += 1
+            queue_id = cnt % self.num_workers
+            try:
+                trans_item = self.worker_queues[queue_id].get(
+                    timeout=MP_QUEUE_GET_TIMEOUT
+                )
+            except queue.Empty:
+                continue
+            trans_items.append(trans_item)
+            if len(trans_items) == self.sampler.batch_size:
+                batch_data = self.collator.apply(trans_items)
+                while True:
+                    try:
+                        self.batch_queue.put(batch_data, timeout=1)
+                        break
+                    except queue.Full:
+                        if self.shutdown_flag.value == 1:
+                            break
+                        logger.debug("batch queue is full")
+                trans_items = []
+
+    def _check_workers(self):
+        if not self.collator_worker.is_alive():
+            exitcode = self.collator_worker.exitcode
+            if exitcode != 0:
+                raise RuntimeError("collator worker died. {}".format(exitcode))
+
+        for worker_id, worker in enumerate(self.workers):
+            if not worker.is_alive():
+                exitcode = worker.exitcode
+                if exitcode != 0:
+                    raise RuntimeError(
+                        "worker: {} died. {}".format(worker_id, exitcode)
+                    )
+
+    def _try_get_next_batch(self):
+        start_time = time.time()
+        while True:
+            self._check_workers()
+            try:
+                return self.batch_queue.get(timeout=1)
+            except queue.Empty:
+                logger.debug("batch queue empty!")
+            waited_time = time.time() - start_time
+            if self.timeout > 0 and waited_time > self.timeout:
+                raise RuntimeError("get_next_batch timeout!")
+
+    def _get_next_batch(self):
+        batch_data = self._try_get_next_batch()
+        return batch_data
+
+    def _shutdown(self):
+        with self.shutdown_flag.get_lock():
+            self.shutdown_flag.value = 1
+
+        if self.collator_worker.is_alive():
+            self.collator_worker.terminate()
+        self.collator_worker.join()
+
+        for worker in self.workers:
+            if worker.is_alive():
+                worker.terminate()
+            worker.join()
+
+        for q in self.worker_queues:
+            q.cancel_join_thread()
+            q.close()
+
+        self.batch_queue.cancel_join_thread()
+        self.batch_queue.close()
+
+    def __del__(self):
+        if self.__initialized:
+            self._shutdown()


def _task_feeding_loop(
    indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
):
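
In the parallel stream iterator above, each of the num_workers processes runs _gen_data over its own copy of the dataset iterator and feeds a per-worker queue, while a single collator process (_gen_batch) drains those queues round-robin, groups batch_size transformed samples, and pushes the collated batch into a plasma-backed shared-memory queue. Below is a standalone sketch of that fan-in pattern; the names are made up and plain multiprocessing queues stand in for PlasmaShmQueue.

    # Hypothetical, self-contained sketch of the producers -> round-robin
    # collator fan-in used by _gen_data/_gen_batch.
    import multiprocessing
    import queue

    def producer(q, worker_id):
        for i in range(4):
            q.put((worker_id, i))  # stand-in for a transformed sample

    def main():
        num_workers = 2
        qs = [multiprocessing.Queue(maxsize=1) for _ in range(num_workers)]
        procs = [
            multiprocessing.Process(target=producer, args=(q, wid), daemon=True)
            for wid, q in enumerate(qs)
        ]
        for p in procs:
            p.start()
        batch, cnt = [], -1
        while len(batch) < 8:  # one "batch" of all eight produced samples
            cnt += 1
            try:
                batch.append(qs[cnt % num_workers].get(timeout=1))
            except queue.Empty:
                continue  # slow worker: move on to the next queue slot
        print(batch)
        for p in procs:
            p.join()

    if __name__ == "__main__":
        main()

Note that each _gen_data worker iterates the whole dataset independently, which is exactly the duplicate-data caveat the StreamSampler docstring in sampler.py warns about.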


imperative/python/megengine/data/sampler.py (+37, -5)

@@ -8,7 +8,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections.abc
import math
-from abc import ABC
+from abc import ABC, abstractmethod
from typing import Any, Generator, Iterator, List, Union

import numpy as np
@@ -17,6 +17,16 @@ import megengine.distributed as dist


class Sampler(ABC):
+    r"""
+    An abstract class for all Sampler
+    """
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+
+class MapSampler(Sampler):
    def __init__(
        self,
        dataset,
@@ -145,7 +155,29 @@ class Sampler(ABC):
        return iter(batch_index)


-class SequentialSampler(Sampler):
+class StreamSampler(Sampler):
+    """
+    Sampler for stream dataset.
+
+    .. warning::
+
+        In the case of multiple workers, sampler should ensure that each worker gets
+        different data. But this class cannot do it yet, please build your own
+        dataset and sampler to achieve this goal.
+
+    """
+
+    def __init__(self, batch_size=1):
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        return range(self.batch_size)
+
+
+class SequentialSampler(MapSampler):
    def __init__(
        self,
        dataset,
@@ -176,7 +208,7 @@ class SequentialSampler(Sampler):
        return self.indices


-class RandomSampler(Sampler):
+class RandomSampler(MapSampler):
    def __init__(
        self,
        dataset,
@@ -205,7 +237,7 @@ class RandomSampler(Sampler):
        return self.rng.permutation(self.indices).tolist()


-class ReplacementSampler(Sampler):
+class ReplacementSampler(MapSampler):
    def __init__(
        self,
        dataset,
@@ -249,7 +281,7 @@ class ReplacementSampler(Sampler):
        return self.rng.multinomial(n, self.weights, self.num_samples).tolist()


-class Infinite(Sampler):
+class Infinite(MapSampler):
    r"""Infinite Sampler warper for basic sampler."""

    def sample(self):
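
StreamSampler is intentionally thin: it never touches a dataset, and every __next__ returns range(batch_size), so the only information it gives the loader is how many samples to pull from the stream per minibatch. A quick illustration, runnable against a build that includes this commit:

    # StreamSampler is an endless iterator of dummy index ranges whose only
    # meaningful property is their length.
    from itertools import islice

    from megengine.data import StreamSampler

    sampler = StreamSampler(batch_size=4)
    for indices in islice(sampler, 3):  # the sampler itself never stops
        print(list(indices))            # [0, 1, 2, 3] each time

One consequence worth noting: StreamSampler defines no __len__, so DataLoader.__len__, which returns len(self.sampler), is undefined for stream loaders.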

