fix(mge/distributed): change api name with preload
fix(mge/distributed): fix recursive model in preload tensor
fix(mge/distributed): fix recursive when cache contain None
GitOrigin-RevId: 80e2a6dd70
tags/v1.6.0
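The patch threads a new `preload` flag through `DataLoader` and its iterators (`megengine/data/dataloader.py`), moves stream-id allocation from distributed static data to a process-global counter (`megengine/device.py`, `megengine/distributed/group.py`), and adds tests. A minimal usage sketch of the flag (dataset shapes mirror `init_dataset()` in the new tests; with `preload=True` the loader yields device Tensors instead of np.ndarray):

import numpy as np

from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset
from megengine.data.sampler import RandomSampler

# Toy dataset, same shapes as init_dataset() in the tests below.
data = np.random.randint(0, 255, size=(100, 1, 32, 32), dtype=np.uint8)
label = np.random.randint(0, 10, size=(100,), dtype=int)
dataset = ArrayDataset(data, label)

# With preload=True each minibatch comes back as megengine Tensors that are
# already on the default device, instead of np.ndarray.
dataloader = DataLoader(
    dataset,
    sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
    preload=True,
)
for batch_data, batch_label in dataloader:
    pass  # batch_data / batch_label are device Tensors here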
--- a/megengine/data/dataloader.py
+++ b/megengine/data/dataloader.py
@@ -15,12 +15,15 @@ import queue
 import random
 import threading
 import time
-from typing import Callable
+from typing import Callable, Union

 import numpy as np

+from ..device import _sh, get_default_device
+from ..functional.tensor import copy
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
+from ..tensor import Tensor
 from .collator import Collator
 from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
@@ -44,7 +47,7 @@ def raise_timeout_error():

 class DataLoader:
     r"""Provides a convenient way to iterate on a given dataset.

     DataLoader combines a dataset with
     :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
     make it flexible to get minibatch continually from a dataset.
@@ -66,6 +69,8 @@ class DataLoader:
             ``True`` means one batch is divided into :attr:`num_workers` pieces, and
             the workers will process these pieces parallelly. ``False`` means
             different sub-process will process different batch. Default: False
+        preload: whether to apply the preload strategy of the dataloader; it
+            overlaps the host-to-device copy with kernel execution to speed up
+            loading. Default: False. When enabled, the output changes from
+            np.ndarray to device Tensor. Supported leaf types are int, float,
+            list[int, float] and tuple[int, float]; other types are not supported.
     """

     __initialized = False
@@ -79,6 +84,7 @@ class DataLoader:
         timeout: int = 0,
         timeout_event: Callable = raise_timeout_error,
         divide: bool = False,
+        preload: bool = False,
     ):
         if num_workers < 0:
             raise ValueError("num_workers should not be negative")

@@ -96,6 +102,7 @@ class DataLoader:
         self.timeout_event = timeout_event
         self.divide = divide
+        self.preload = preload

         if isinstance(dataset, StreamDataset):
             self.sampler = sampler if sampler else StreamSampler(batch_size=1)
@@ -145,24 +152,74 @@ class DataLoader:
             self.num_workers = 0
         if isinstance(self.dataset, StreamDataset):
             if not self.num_workers:
-                return _SerialStreamDataLoaderIter(self)
+                return _SerialStreamDataLoaderIter(self, self.preload)
             else:
-                return _ParallelStreamDataLoaderIter(self)
+                return _ParallelStreamDataLoaderIter(self, self.preload)
         else:
             assert isinstance(
                 self.dataset, Dataset
             ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
             if not self.num_workers:
-                return _SerialMapDataLoaderIter(self)
+                return _SerialMapDataLoaderIter(self, self.preload)
             else:
-                return _ParallelMapDataLoaderIter(self)
+                return _ParallelMapDataLoaderIter(self, self.preload)

     def __len__(self):
         return len(self.sampler)
-class _BaseMapDataLoaderIter:
-    def __init__(self, loader):
+class PreLoader:
+    def __init__(self, preload):
+        if preload:
+            self.default_device = get_default_device()
+            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
+            self.pre_load_device_cache = None
+        self.preload = preload
+
+    """
+    strategy one: load from numpy data and generate dtype tensors
+    """
+
+    def _load_tensor(self, batch, cached=True):
+        if isinstance(batch, np.ndarray):
+            device = self.pre_load_device if cached else self.default_device
+            return Tensor(batch, device=device)
+        elif isinstance(batch, collections.abc.Mapping):
+            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
+        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
+            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
+        elif isinstance(batch, collections.abc.Sequence):
+            return [self._load_tensor(value, cached) for value in batch]
+        else:
+            return batch
+
+    """
+    strategy two: load from a cache that already holds tensors; only a d2d copy is needed
+    """
+
+    def _load_cache(self, data):
+        if isinstance(data, Tensor):
+            if data.device == self.default_device:
+                return data
+            return copy(data, device=self.default_device)
+        elif isinstance(data, collections.abc.Mapping):
+            return {k: self._load_cache(v) for k, v in data.items()}
+        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
+            return type(data)(*(self._load_cache(value) for value in data))
+        elif isinstance(data, collections.abc.Sequence):
+            return [self._load_cache(value) for value in data]
+        else:
+            return data
+
+    def _swap_out_cache(self):
+        out = self._load_cache(self.pre_load_device_cache)
+        self.pre_load_device_cache = None  # clean the cache
+        return out
+
+
+class _BaseMapDataLoaderIter(PreLoader):
+    def __init__(self, loader, preload):
+        super().__init__(preload)
         self.dataset = loader.dataset
         self.sampler = loader.sampler
         self.seed = _random_seed_generator().__next__()
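`_load_tensor` and `_load_cache` are the traversals the second and third commit lines fix: both must bottom out on `None` and on any unrecognized leaf instead of recursing forever. The same pattern, sketched standalone (plain Python; `walk` is a hypothetical name, and the extra `str` guard is my assumption, since a bare `str` is itself a `Sequence`):

import collections.abc

import numpy as np

def walk(batch, leaf_fn):
    # Same shape as PreLoader._load_tensor / _load_cache: recurse into
    # mappings, namedtuples and sequences, transform ndarray leaves, and
    # return everything else (None included) unchanged -- that fallthrough
    # is what stops the recursion on unsupported types.
    if isinstance(batch, np.ndarray):
        return leaf_fn(batch)
    elif isinstance(batch, collections.abc.Mapping):
        return {k: walk(v, leaf_fn) for k, v in batch.items()}
    elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
        return type(batch)(*(walk(v, leaf_fn) for v in batch))
    elif isinstance(batch, str):  # assumed guard: str is a Sequence too
        return batch
    elif isinstance(batch, collections.abc.Sequence):
        return [walk(v, leaf_fn) for v in batch]
    else:
        return batch

# None and plain ints pass through untouched; only ndarrays are transformed.
print(walk({"x": [np.zeros(2), None, 3]}, lambda a: a + 1))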
@@ -184,16 +241,35 @@ class _BaseMapDataLoaderIter:
         return self

     def __next__(self):
-        if self.num_processed >= len(self):
-            raise StopIteration
-        minibatch = self._get_next_batch()
-        self.num_processed += 1
-        return minibatch
+        if self.preload:
+            cached = self.pre_load_device_cache
+            if cached is None:  # only the first and the last batch hit this
+                if self.num_processed >= len(self):  # last
+                    raise StopIteration
+                elif self.num_processed == 0:  # first
+                    self._try_load_tensor(cached=False)  # blocking h2d load
+            out = self._swap_out_cache()
+            self._try_load_tensor()
+            return out
+        else:
+            if self.num_processed >= len(self):
+                raise StopIteration
+            minibatch = self._get_next_batch()
+            self.num_processed += 1
+            return minibatch
+
+    def _try_load_tensor(self, cached=True):
+        if self.num_processed >= len(self):
+            return
+        else:
+            self.num_processed += 1
+            batch = self._get_next_batch()
+            self.pre_load_device_cache = self._load_tensor(batch, cached)
 class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
-    def __init__(self, loader):
-        super(_SerialMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
         self.indices_iter = iter(self.sampler)

     def _get_next_batch(self):
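The `__next__`/`_try_load_tensor` pair above is a two-slot pipeline: the first call does a blocking host-to-device load, and every later call hands out the batch staged on the side device (a d2d copy via `_swap_out_cache`) while immediately staging the next one. The control flow, reduced to a generator (hypothetical names, no megengine):

def preloading_iter(get_next_batch, n, load, swap_out):
    # Control flow of _BaseMapDataLoaderIter.__next__ with preload=True.
    cache = None
    processed = 0
    while True:
        if cache is None:
            if processed >= n:              # last: nothing staged, nothing left
                return
            cache = load(get_next_batch())  # first: blocking h2d load
            processed += 1
        out = swap_out(cache)               # d2d copy onto the compute device
        if processed < n:
            cache = load(get_next_batch())  # stage the next batch ahead of time
            processed += 1
        else:
            cache = None
        yield out

batches = iter(range(3))
assert list(preloading_iter(lambda: next(batches), 3, lambda b: b, lambda c: c)) == [0, 1, 2]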
@@ -206,8 +282,8 @@ class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):

 class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
     __initialized = False

-    def __init__(self, loader):
-        super(_ParallelMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -358,8 +434,9 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
         self._shutdown()


-class _BaseStreamDataLoaderIter:
-    def __init__(self, loader):
+class _BaseStreamDataLoaderIter(PreLoader):
+    def __init__(self, loader, preload):
+        super().__init__(preload)
         self.dataset = loader.dataset
         self.sampler = loader.sampler
         self.transform = loader.transform
@@ -388,12 +465,23 @@ class _BaseStreamDataLoaderIter:
         return self

     def __next__(self):
-        return self._get_next_batch()
+        if self.preload:
+            if self.pre_load_device_cache is None:
+                self._try_load_tensor(cached=False)  # blocking load onto the current device
+            out = self._swap_out_cache()
+            self._try_load_tensor()  # stage the next batch on the preload device
+            return out
+        else:
+            return self._get_next_batch()
+
+    def _try_load_tensor(self, cached=True):
+        batch = self._get_next_batch()
+        self.pre_load_device_cache = self._load_tensor(batch, cached)


 class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)
         self.idx = 0
         self.unused = []
@@ -439,8 +527,8 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):

 class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
     __initialized = False

-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)

         self.shutdown_flag = multiprocessing.Value("i", 0)
--- a/megengine/device.py
+++ b/megengine/device.py
@@ -29,6 +29,19 @@ __all__ = [
 ]


+class _stream_helper:
+    def __init__(self):
+        self.stream = 1
+
+    def get_next(self):
+        out = self.stream
+        self.stream = self.stream + 1
+        return out
+
+
+_sh = _stream_helper()
+
+
 def _valid_device(inp):
     if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
         return True
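`_sh` is just a process-wide, monotonically increasing counter starting at 1; the dataloader turns the next id into a side-device string. Roughly (the device string shown is illustrative):

from megengine.device import _sh, get_default_device

first = _sh.get_next()   # 1 on a fresh process
second = _sh.get_next()  # 2
# This is how PreLoader builds its staging device, e.g. "gpu0:3" here:
pre_load_device = get_default_device() + ":" + str(_sh.get_next())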
--- a/megengine/distributed/group.py
+++ b/megengine/distributed/group.py
@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple

 from mprop import mproperty

-from ..device import set_default_device, what_is_xpu
+from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server

@@ -27,7 +27,6 @@ class StaticData:
     proc_rank = None
     device = None
     backend = None
-    next_stream = None
     device_type = None
     machine_ranks = None
@@ -43,6 +42,8 @@ class Group:

     Args:
         proc_ranks: rank list of the group, the first one is root rank.
     """

     def __init__(self, proc_ranks):

@@ -55,9 +56,7 @@ class Group:
     def reset(self, proc_ranks):
         self.check(proc_ranks)
         self.proc_ranks = proc_ranks
-        self.stream = _sd.next_stream
-        _sd.next_stream += 1
-        self.is_single_machine_cache = None
+        self.stream = _sh.get_next()

     def check(self, proc_ranks):
         assert _sd is not None, "please call init_process_group first"
@@ -160,7 +159,6 @@ def init_process_group(
     _sd.proc_rank = rank
     _sd.device = device
     _sd.backend = backend
-    _sd.next_stream = 1
     _sd.device_type = device_type

     WORLD.reset(list(range(world_size)))
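The motivation for moving the counter: stream ids used to live on distributed static data (`_sd.next_stream`), so they only existed after `init_process_group()`, while the dataloader needs a stream of its own even in single-process runs. With the shared counter, groups and dataloaders draw from the same sequence:

from megengine.device import _sh

# Before (removed above):
#     self.stream = _sd.next_stream
#     _sd.next_stream += 1
# After: any consumer, distributed or not, reserves the next stream id.
group_stream = _sh.get_next()   # e.g. a collective communication group
loader_stream = _sh.get_next()  # e.g. a DataLoader staging device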
--- /dev/null
+++ b/test_dataloader.py
@@ -0,0 +1,308 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import os
import platform
import time

import numpy as np
import pytest

from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
    Compose,
    Normalize,
    PseudoTransform,
    ToMode,
    Transform,
)


def init_dataset():
    sample_num = 100
    rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
    label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
    dataset = ArrayDataset(rand_data, label)
    return dataset
def test_dataloader_init():
    dataset = init_dataset()
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=2, divide=True)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=-1)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, timeout=-1)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=0, divide=True)

    dataloader = DataLoader(dataset, preload=True)
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert isinstance(dataloader.transform, PseudoTransform)
    assert isinstance(dataloader.collator, Collator)

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=6, drop_last=False),
        preload=True,
    )
    assert len(dataloader) == 17
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=6, drop_last=True),
        preload=True,
    )
    assert len(dataloader) == 16
class MyStream(StreamDataset):
    def __init__(self, number, batch=False, error_format=False, block=False):
        self.number = number
        self.batch = batch
        self.error_format = error_format
        self.block = block

    def __iter__(self):
        for cnt in range(self.number):
            if self.block:
                for _ in range(10):
                    time.sleep(1)
            if self.batch:
                data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
                yield (True, (data, [cnt, cnt - self.number]))
            else:
                data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
                if self.error_format:
                    yield (data, cnt)
                else:
                    yield (False, (data, cnt))
        return  # a bare `raise StopIteration` inside a generator becomes a RuntimeError (PEP 479)
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
    dataset = MyStream(100, batch=batch)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=num_workers,
        preload=True,
    )
    check_set = set()

    for step, data in enumerate(dataloader):
        if step == 10:
            break
        assert data[0]._tuple_shape == (4, 3, 2, 2)
        assert data[1]._tuple_shape == (4,)
        for i in data[1]:
            assert i not in check_set
            check_set.add(i)
def test_stream_dataloader_error():
    dataset = MyStream(100, error_format=True)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(dataset, sampler, preload=True)
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        data_iter = iter(dataloader)
        next(data_iter)


@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
    dataset = MyStream(100, False, block=True)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset, sampler, num_workers=num_workers, timeout=2, preload=True
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        next(data_iter)
def test_dataloader_serial():
    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)


def test_dataloader_parallel():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
        divide=False,
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
        divide=True,
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader does not support parallel on windows",
)
def test_dataloader_parallel_timeout():
    dataset = init_dataset()

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=TimeoutTransform(),
        num_workers=2,
        timeout=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader does not support parallel on windows",
)
def test_dataloader_parallel_worker_exception():
    dataset = init_dataset()

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            y = x + 1  # x is deliberately undefined: the worker raises and dies
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=FakeErrorTransform(),
        num_workers=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r"worker.*died"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
def _multi_instances_parallel_dataloader_worker():
    dataset = init_dataset()

    for divide_flag in [True, False]:
        train_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        val_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        for idx, (data, label) in enumerate(train_dataloader):
            assert data._tuple_shape == (4, 1, 32, 32)
            assert label._tuple_shape == (4,)
            if idx % 5 == 0:
                for val_data, val_label in val_dataloader:
                    assert val_data._tuple_shape == (10, 1, 32, 32)
                    assert val_label._tuple_shape == (10,)
def test_dataloader_parallel_multi_instances():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    _multi_instances_parallel_dataloader_worker()


@pytest.mark.isolated_distributed
def test_dataloader_parallel_multi_instances_multiprocessing():
    gc.collect()
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    import multiprocessing as mp

    # mp.set_start_method("spawn")
    processes = []
    for i in range(4):
        p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert p.exitcode == 0
@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
    def cb():
        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

    dataset = MyStream(100, block=True)
    sampler = StreamSampler(batch_size=4)

    dataloader = DataLoader(
        dataset,
        sampler,
        num_workers=num_workers,
        timeout=2,
        timeout_event=cb,
        preload=True,
    )
    for _, data in enumerate(dataloader):
        np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
        np.testing.assert_equal(data[1], np.ones(shape=(4,)))
        break