fix(mge/distributed): change api name with preload
fix(mge/distributed): fix recursive model in preload tensor
fix(mge/distributed): fix recursive when cache contain None
GitOrigin-RevId: 80e2a6dd70
tags/v1.6.0
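
For reference, a minimal end-to-end sketch of the new `preload` flag, distilled from the unit tests added at the bottom of this patch (the dataset contents are random placeholders):

    import numpy as np

    from megengine.data.dataloader import DataLoader
    from megengine.data.dataset import ArrayDataset
    from megengine.data.sampler import RandomSampler

    # a small map-style dataset, mirroring the unit tests below
    data = np.random.randint(0, 255, size=(100, 1, 32, 32), dtype=np.uint8)
    label = np.random.randint(0, 10, size=(100,), dtype=int)
    dataset = ArrayDataset(data, label)

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        preload=True,  # overlap the host-to-device copy with kernel execution
    )
    for batch_data, batch_label in dataloader:
        # with preload=True, batches arrive as device Tensors, not np.ndarray
        assert batch_data._tuple_shape == (4, 1, 32, 32)
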
@@ -15,12 +15,15 @@ import queue
 import random
 import threading
 import time
-from typing import Callable
+from typing import Callable, Union

 import numpy as np

+from ..device import _sh, get_default_device
+from ..functional.tensor import copy
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
+from ..tensor import Tensor
 from .collator import Collator
 from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
@@ -44,7 +47,7 @@ def raise_timeout_error():
 class DataLoader:
     r"""Provides a convenient way to iterate on a given dataset.

     DataLoader combines a dataset with
     :class:`~.Sampler`, :class:`~.Transform` and :class:`~.Collator`,
     making it flexible to get minibatches continually from a dataset.
@@ -66,6 +69,10 @@ class DataLoader:
             ``True`` means one batch is divided into :attr:`num_workers` pieces, and
             the workers will process these pieces in parallel. ``False`` means
             different sub-processes will process different batches. Default: False
+        preload: whether to apply the preloading strategy of the dataloader, which overlaps
+            the host-to-device copy with kernel execution to speed up loading. Default: False.
+            When enabled, each output changes from np.ndarray to Tensor; the supported element
+            types are int, float, list[int, float] and tuple[int, float]; no other type is supported.
     """

     __initialized = False
@@ -79,6 +84,7 @@ class DataLoader:
         timeout: int = 0,
         timeout_event: Callable = raise_timeout_error,
         divide: bool = False,
+        preload: bool = False,
     ):
         if num_workers < 0:
             raise ValueError("num_workers should not be negative")
@@ -96,6 +102,7 @@ class DataLoader:
         self.timeout_event = timeout_event
         self.divide = divide
+        self.preload = preload

         if isinstance(dataset, StreamDataset):
             self.sampler = sampler if sampler else StreamSampler(batch_size=1)
@@ -145,24 +152,74 @@ class DataLoader:
             self.num_workers = 0

         if isinstance(self.dataset, StreamDataset):
             if not self.num_workers:
-                return _SerialStreamDataLoaderIter(self)
+                return _SerialStreamDataLoaderIter(self, self.preload)
             else:
-                return _ParallelStreamDataLoaderIter(self)
+                return _ParallelStreamDataLoaderIter(self, self.preload)
         else:
             assert isinstance(
                 self.dataset, Dataset
             ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
             if not self.num_workers:
-                return _SerialMapDataLoaderIter(self)
+                return _SerialMapDataLoaderIter(self, self.preload)
             else:
-                return _ParallelMapDataLoaderIter(self)
+                return _ParallelMapDataLoaderIter(self, self.preload)

     def __len__(self):
         return len(self.sampler)


+class PreLoader:
+    def __init__(self, preload):
+        if preload:
+            self.default_device = get_default_device()
+            self.pre_load_device = self.default_device + ":" + str(_sh.get_next())
+            self.pre_load_device_cache = None
+        self.preload = preload
+
+    """
+    Strategy one: the batch is raw numpy data; build Tensors on the target device (h2d copy).
+    """
+
+    def _load_tensor(self, batch, cached=True):
+        if isinstance(batch, np.ndarray):
+            device = self.pre_load_device if cached else self.default_device
+            return Tensor(batch, device=device)
+        elif isinstance(batch, collections.abc.Mapping):
+            return {k: self._load_tensor(v, cached) for k, v in batch.items()}
+        elif isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
+            return type(batch)(*(self._load_tensor(value, cached) for value in batch))
+        elif isinstance(batch, collections.abc.Sequence):
+            return [self._load_tensor(value, cached) for value in batch]
+        else:
+            return batch
+
+    """
+    Strategy two: the cache already holds Tensors; only a d2d copy to the default device is needed.
+    """
+
+    def _load_cache(self, data):
+        if isinstance(data, Tensor):
+            if data.device == self.default_device:
+                return data
+            return copy(data, device=self.default_device)
+        elif isinstance(data, collections.abc.Mapping):
+            return {k: self._load_cache(v) for k, v in data.items()}
+        elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
+            return type(data)(*(self._load_cache(value) for value in data))
+        elif isinstance(data, collections.abc.Sequence):
+            return [self._load_cache(value) for value in data]
+        else:
+            return data
+
+    def _swap_out_cache(self):
+        out = self._load_cache(self.pre_load_device_cache)
+        self.pre_load_device_cache = None  # clean the cache
+        return out
+
+
-class _BaseMapDataLoaderIter:
-    def __init__(self, loader):
+class _BaseMapDataLoaderIter(PreLoader):
+    def __init__(self, loader, preload):
+        super().__init__(preload)
         self.dataset = loader.dataset
         self.sampler = loader.sampler
         self.seed = _random_seed_generator().__next__()
@@ -184,16 +241,35 @@ class _BaseMapDataLoaderIter:
         return self

     def __next__(self):
-        if self.num_processed >= len(self):
-            raise StopIteration
-        minibatch = self._get_next_batch()
-        self.num_processed += 1
-        return minibatch
+        if self.preload:
+            cached = self.pre_load_device_cache
+            if cached is None:  # first and last
+                if self.num_processed >= len(self):  # last
+                    raise StopIteration
+                elif self.num_processed == 0:  # first
+                    self._try_load_tensor(cached=False)  # first do the h2d
+            out = self._swap_out_cache()
+            self._try_load_tensor()
+            return out
+        else:
+            if self.num_processed >= len(self):
+                raise StopIteration
+            minibatch = self._get_next_batch()
+            self.num_processed += 1
+            return minibatch
+
+    def _try_load_tensor(self, cached=True):
+        if self.num_processed >= len(self):
+            return
+        else:
+            self.num_processed += 1
+            batch = self._get_next_batch()
+            self.pre_load_device_cache = self._load_tensor(batch, cached)


 class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
-    def __init__(self, loader):
-        super(_SerialMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_SerialMapDataLoaderIter, self).__init__(loader, preload)
         self.indices_iter = iter(self.sampler)

     def _get_next_batch(self):
@@ -206,8 +282,8 @@ class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
 class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
     __initialized = False

-    def __init__(self, loader):
-        super(_ParallelMapDataLoaderIter, self).__init__(loader)
+    def __init__(self, loader, preload):
+        super(_ParallelMapDataLoaderIter, self).__init__(loader, preload)

         self.task_queues = [
             multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
@@ -358,8 +434,9 @@ class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
         self._shutdown()


-class _BaseStreamDataLoaderIter:
-    def __init__(self, loader):
+class _BaseStreamDataLoaderIter(PreLoader):
+    def __init__(self, loader, preload):
+        super().__init__(preload)
         self.dataset = loader.dataset
         self.sampler = loader.sampler
         self.transform = loader.transform
@@ -388,12 +465,23 @@ class _BaseStreamDataLoaderIter:
         return self

     def __next__(self):
-        return self._get_next_batch()
+        if self.preload:
+            if self.pre_load_device_cache is None:
+                self._try_load_tensor(cached=False)  # load in current
+            out = self._swap_out_cache()
+            self._try_load_tensor()  # load in cached
+            return out
+        else:
+            return self._get_next_batch()
+
+    def _try_load_tensor(self, cached=True):
+        batch = self._get_next_batch()
+        self.pre_load_device_cache = self._load_tensor(batch, cached)


 class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)
         self.dataset_iter = iter(self.dataset)
         self.idx = 0
         self.unused = []
@@ -439,8 +527,8 @@ class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
 class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
     __initialized = False

-    def __init__(self, loader):
-        super().__init__(loader)
+    def __init__(self, loader, preload):
+        super().__init__(loader, preload)

         self.shutdown_flag = multiprocessing.Value("i", 0)
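
The map-style iterator above implements a classic double-buffering loop: one batch is prefetched onto a side stream device while the previous one is swapped out to the default device. The following self-contained sketch (plain Python with invented names, device copies stubbed out with ordinary assignments, not MegEngine API) mirrors the control flow of `__next__` and `_try_load_tensor`:

    class DoubleBufferedIter:
        def __init__(self, batches):
            self._batches = iter(batches)
            self._cache = None       # plays the role of pre_load_device_cache
            self._exhausted = False

        def _try_load(self):
            # stands in for _try_load_tensor: fetch the next batch into the side buffer (h2d)
            try:
                self._cache = next(self._batches)
            except StopIteration:
                self._cache = None
                self._exhausted = True

        def __next__(self):
            if self._cache is None:
                if self._exhausted:  # "last": nothing left to swap out
                    raise StopIteration
                self._try_load()     # "first": prime the buffer
            out = self._cache        # stands in for _swap_out_cache (the d2d copy)
            self._cache = None
            self._try_load()         # refill while the caller consumes `out`
            return out

        def __iter__(self):
            return self

    # e.g. list(DoubleBufferedIter([1, 2, 3])) == [1, 2, 3]
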
@@ -29,6 +29,19 @@ __all__ = [
 ]


+class _stream_helper:
+    def __init__(self):
+        self.stream = 1
+
+    def get_next(self):
+        out = self.stream
+        self.stream = self.stream + 1
+        return out
+
+
+_sh = _stream_helper()
+
+
 def _valid_device(inp):
     if isinstance(inp, str) and re.match("^([cxg]pu|rocm)(\d+|\d+:\d+|x)$", inp):
         return True
@@ -12,7 +12,7 @@ from typing import List, Optional, Tuple
 from mprop import mproperty

-from ..device import set_default_device, what_is_xpu
+from ..device import _sh, set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server
@@ -27,7 +27,6 @@ class StaticData:
     proc_rank = None
     device = None
     backend = None
-    next_stream = None
     device_type = None
     machine_ranks = None
@@ -43,6 +42,8 @@ class Group:
     Args:
         proc_ranks: rank list of the group, the first one is root rank.
     """

     def __init__(self, proc_ranks):
@@ -55,9 +56,7 @@ class Group:
     def reset(self, proc_ranks):
         self.check(proc_ranks)
         self.proc_ranks = proc_ranks
-        self.stream = _sd.next_stream
-        _sd.next_stream += 1
-        self.is_single_machine_cache = None
+        self.stream = _sh.get_next()

     def check(self, proc_ranks):
         assert _sd is not None, "please call init_process_group first"
@@ -160,7 +159,6 @@ def init_process_group(
     _sd.proc_rank = rank
     _sd.device = device
     _sd.backend = backend
-    _sd.next_stream = 1
     _sd.device_type = device_type

     WORLD.reset(list(range(world_size)))
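
With this change, DataLoader preloading and distributed Groups draw their stream ids from the same process-wide `_sh` counter in `megengine.device`, so the two features can no longer hand out the same stream. A minimal sketch of the allocator's behavior (assuming `_sh` is importable as added above; the ids are illustrative, and stream 0 is presumably left for the default compute stream since the counter starts at 1):

    from megengine.device import _sh

    s1 = _sh.get_next()  # e.g. taken by a distributed Group on reset()
    s2 = _sh.get_next()  # e.g. taken by a DataLoader's preload buffer
    assert s2 == s1 + 1  # strictly increasing ids, so they never collide
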
@@ -0,0 +1,308 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc
import os
import platform
import time

import numpy as np
import pytest

from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import (
    Compose,
    Normalize,
    PseudoTransform,
    ToMode,
    Transform,
)


def init_dataset():
    sample_num = 100
    rand_data = np.random.randint(0, 255, size=(sample_num, 1, 32, 32), dtype=np.uint8)
    label = np.random.randint(0, 10, size=(sample_num,), dtype=int)
    dataset = ArrayDataset(rand_data, label)
    return dataset


def test_dataloader_init():
    dataset = init_dataset()
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=2, divide=True)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=-1)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, timeout=-1)
    with pytest.raises(ValueError):
        dataloader = DataLoader(dataset, num_workers=0, divide=True)

    dataloader = DataLoader(dataset, preload=True)
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert isinstance(dataloader.transform, PseudoTransform)
    assert isinstance(dataloader.collator, Collator)

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=6, drop_last=False),
        preload=True,
    )
    assert len(dataloader) == 17
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=6, drop_last=True),
        preload=True,
    )
    assert len(dataloader) == 16
class MyStream(StreamDataset):
    def __init__(self, number, batch=False, error_format=False, block=False):
        self.number = number
        self.batch = batch
        self.error_format = error_format
        self.block = block

    def __iter__(self):
        for cnt in range(self.number):
            if self.block:
                for _ in range(10):
                    time.sleep(1)
            if self.batch:
                data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
                yield (True, (data, [cnt, cnt - self.number]))
            else:
                data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
                if self.error_format:
                    yield (data, cnt)
                else:
                    yield (False, (data, cnt))
        # the generator simply ends here; raising StopIteration inside a
        # generator would surface as RuntimeError under PEP 479
@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
    dataset = MyStream(100, batch=batch)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(
        dataset,
        sampler,
        Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
        num_workers=num_workers,
        preload=True,
    )
    check_set = set()

    for step, data in enumerate(dataloader):
        if step == 10:
            break
        assert data[0]._tuple_shape == (4, 3, 2, 2)
        assert data[1]._tuple_shape == (4,)
        for i in data[1]:
            assert i not in check_set
            check_set.add(i)
def test_stream_dataloader_error():
    dataset = MyStream(100, error_format=True)
    sampler = StreamSampler(batch_size=4)
    dataloader = DataLoader(dataset, sampler, preload=True)
    with pytest.raises(AssertionError, match=r".*tuple.*"):
        data_iter = iter(dataloader)
        next(data_iter)
@pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers):
    dataset = MyStream(100, False, block=True)
    sampler = StreamSampler(batch_size=4)

    dataloader = DataLoader(
        dataset, sampler, num_workers=num_workers, timeout=2, preload=True
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        next(data_iter)


def test_dataloader_serial():
    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)


def test_dataloader_parallel():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    dataset = init_dataset()
    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
        divide=False,
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        num_workers=2,
        divide=True,
        preload=True,
    )
    for (data, label) in dataloader:
        assert data._tuple_shape == (4, 1, 32, 32)
        assert label._tuple_shape == (4,)


@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader does not support parallel on windows",
)
def test_dataloader_parallel_timeout():
    dataset = init_dataset()

    class TimeoutTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            time.sleep(10)
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=TimeoutTransform(),
        num_workers=2,
        timeout=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r".*timeout.*"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
@pytest.mark.skipif(
    platform.system() == "Windows",
    reason="dataloader does not support parallel on windows",
)
def test_dataloader_parallel_worker_exception():
    dataset = init_dataset()

    class FakeErrorTransform(Transform):
        def __init__(self):
            pass

        def apply(self, input):
            y = x + 1  # `x` is undefined on purpose: the NameError should kill the worker
            return input

    dataloader = DataLoader(
        dataset,
        sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
        transform=FakeErrorTransform(),
        num_workers=2,
        preload=True,
    )
    with pytest.raises(RuntimeError, match=r"worker.*died"):
        data_iter = iter(dataloader)
        batch_data = next(data_iter)
def _multi_instances_parallel_dataloader_worker():
    dataset = init_dataset()

    for divide_flag in [True, False]:
        train_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        val_dataloader = DataLoader(
            dataset,
            sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
            num_workers=2,
            divide=divide_flag,
            preload=True,
        )
        for idx, (data, label) in enumerate(train_dataloader):
            assert data._tuple_shape == (4, 1, 32, 32)
            assert label._tuple_shape == (4,)
            if idx % 5 == 0:
                for val_data, val_label in val_dataloader:
                    assert val_data._tuple_shape == (10, 1, 32, 32)
                    assert val_label._tuple_shape == (10,)


def test_dataloader_parallel_multi_instances():
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    _multi_instances_parallel_dataloader_worker()


@pytest.mark.isolated_distributed
def test_dataloader_parallel_multi_instances_multiprocessing():
    gc.collect()
    # set max shared memory to 100M
    os.environ["MGE_PLASMA_MEMORY"] = "100000000"

    import multiprocessing as mp

    # mp.set_start_method("spawn")
    processes = []
    for i in range(4):
        p = mp.Process(target=_multi_instances_parallel_dataloader_worker)
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert p.exitcode == 0


@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
    def cb():
        return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

    dataset = MyStream(100, block=True)
    sampler = StreamSampler(batch_size=4)

    dataloader = DataLoader(
        dataset,
        sampler,
        num_workers=num_workers,
        timeout=2,
        timeout_event=cb,
        preload=True,
    )
    for _, data in enumerate(dataloader):
        np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
        np.testing.assert_equal(data[1], np.ones(shape=(4,)))
        break