Browse Source

perf(imperative/data): improve dataloader preformance

GitOrigin-RevId: 7d8d52aaeb
tags/v1.11.1
Megvii Engine Team 3 years ago
parent
commit
edc92ccfd6
4 changed files with 524 additions and 732 deletions
  1. +363
    -563
      imperative/python/megengine/data/dataloader.py
  2. +20
    -9
      imperative/python/megengine/data/sampler.py
  3. +106
    -73
      imperative/python/test/unit/data/test_dataloader.py
  4. +35
    -87
      imperative/python/test/unit/data/test_pre_dataloader.py

+ 363
- 563
imperative/python/megengine/data/dataloader.py
File diff suppressed because it is too large
View File


+ 20
- 9
imperative/python/megengine/data/sampler.py View File

@@ -2,6 +2,7 @@
import collections.abc import collections.abc
import math import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import count
from typing import Any, Generator, Iterator, List, Union from typing import Any, Generator, Iterator, List, Union


import numpy as np import numpy as np
@@ -126,13 +127,15 @@ class MapSampler(Sampler):
if self.world_size > 1: if self.world_size > 1:
indices = self.scatter(indices) indices = self.scatter(indices)


step, length = self.batch_size, len(indices)
batch_index = [indices[i : i + step] for i in range(0, length, step)]
batch = []
for idx in indices:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []


if self.drop_last and len(batch_index[-1]) < self.batch_size:
batch_index.pop()

return iter(batch_index)
if len(batch) > 0 and not self.drop_last:
yield batch




class StreamSampler(Sampler): class StreamSampler(Sampler):
@@ -151,10 +154,18 @@ class StreamSampler(Sampler):
self.batch_size = batch_size self.batch_size = batch_size


def __iter__(self): def __iter__(self):
return self
return self.batch()


def __next__(self):
return iter(range(self.batch_size))
def batch(self):
batch = []
for idx in self.sample():
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []

def sample(self):
return count(start=0)




class SequentialSampler(MapSampler): class SequentialSampler(MapSampler):


+ 106
- 73
imperative/python/test/unit/data/test_dataloader.py View File

@@ -1,4 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
import os import os
import platform import platform
import time import time
@@ -7,7 +15,7 @@ import numpy as np
import pytest import pytest


from megengine.data.collator import Collator from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import ( from megengine.data.transform import (
@@ -29,14 +37,10 @@ def init_dataset():


def test_dataloader_init(): def test_dataloader_init():
dataset = init_dataset() dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1) dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1) dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)


dataloader = DataLoader(dataset) dataloader = DataLoader(dataset)
assert isinstance(dataloader.sampler, SequentialSampler) assert isinstance(dataloader.sampler, SequentialSampler)
@@ -54,10 +58,8 @@ def test_dataloader_init():




class MyStream(StreamDataset): class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False):
def __init__(self, number, block=False):
self.number = number self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block self.block = block


def __iter__(self): def __iter__(self):
@@ -65,22 +67,14 @@ class MyStream(StreamDataset):
if self.block: if self.block:
for _ in range(10): for _ in range(10):
time.sleep(1) time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
yield (data, cnt)
raise StopIteration raise StopIteration




@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
@@ -90,7 +84,6 @@ def test_stream_dataloader(batch, num_workers):
) )


check_set = set() check_set = set()

for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
if step == 10: if step == 10:
break break
@@ -101,18 +94,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i) check_set.add(i)




def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)


@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers): def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)


dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2) dataloader = DataLoader(dataset, sampler, num_workers=num_workers, timeout=2)
@@ -140,17 +124,6 @@ def test_dataloader_parallel():
dataset, dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2, num_workers=2,
divide=False,
)
for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)

dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
) )
for (data, label) in dataloader: for (data, label) in dataloader:
assert data.shape == (4, 1, 32, 32) assert data.shape == (4, 1, 32, 32)
@@ -205,7 +178,7 @@ def test_dataloader_parallel_worker_exception():
transform=FakeErrorTransform(), transform=FakeErrorTransform(),
num_workers=2, num_workers=2,
) )
with pytest.raises(RuntimeError, match=r"worker.*died"):
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader) data_iter = iter(dataloader)
batch_data = next(data_iter) batch_data = next(data_iter)


@@ -213,26 +186,23 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker(): def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset() dataset = init_dataset()


for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data.shape == (10, 1, 32, 32)
assert val_label.shape == (10,)
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data.shape == (4, 1, 32, 32)
assert label.shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data.shape == (10, 1, 32, 32)
assert val_label.shape == (10,)




def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
@@ -261,18 +231,81 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
assert p.exitcode == 0 assert p.exitcode == 0




@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))
def partition(ls, size):
return [ls[i : i + size] for i in range(0, len(ls), size)]


dataset = MyStream(100, block=True)

class MyPreStream(StreamDataset):
def __init__(self, number, block=False):
self.number = [i for i in range(number)]
self.block = block
self.data = []
for i in range(100):
self.data.append(np.random.randint(0, 256, (2, 2, 3), dtype="uint8"))

def __iter__(self):
worker_info = get_worker_info()
per_worker = int(math.ceil((len(self.data)) / float(worker_info.worker)))
pre_data = iter(partition(self.data, per_worker)[worker_info.idx])
pre_cnt = partition(self.number, per_worker)[worker_info.idx]
for cnt in pre_cnt:
if self.block:
for _ in range(10):
time.sleep(1)
yield (next(pre_data), cnt)
raise StopIteration


@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_prestream_dataloader_multiprocessing():
dataset = MyPreStream(100)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(
dataset,
sampler,
Compose([Normalize(mean=(103, 116, 123), std=(57, 57, 58)), ToMode("CHW")]),
num_workers=2,
parallel_stream=True,
)

check_set = set()

for step, data in enumerate(dataloader):
if step == 10:
break
assert data[0].shape == (4, 3, 2, 2)
assert data[1].shape == (4,)
for i in data[1]:
assert i not in check_set
check_set.add(i)


@pytest.mark.skipif(
platform.system() == "Windows",
reason="dataloader do not support parallel on windows",
)
def test_predataloader_parallel_worker_exception():
dataset = MyPreStream(100)

class FakeErrorTransform(Transform):
def __init__(self):
pass

def apply(self, input):
raise RuntimeError("test raise error")
return input


dataloader = DataLoader( dataloader = DataLoader(
dataset, sampler, num_workers=num_workers, timeout=2, timeout_event=cb
dataset,
sampler=StreamSampler(batch_size=4),
transform=FakeErrorTransform(),
num_workers=2,
parallel_stream=True,
) )
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader)
batch_data = next(data_iter)
print(batch_data.shape)

+ 35
- 87
imperative/python/test/unit/data/test_pre_dataloader.py View File

@@ -1,5 +1,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import gc import gc
import math
import os import os
import platform import platform
import time import time
@@ -8,7 +16,7 @@ import numpy as np
import pytest import pytest


from megengine.data.collator import Collator from megengine.data.collator import Collator
from megengine.data.dataloader import DataLoader
from megengine.data.dataloader import DataLoader, get_worker_info
from megengine.data.dataset import ArrayDataset, StreamDataset from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler from megengine.data.sampler import RandomSampler, SequentialSampler, StreamSampler
from megengine.data.transform import ( from megengine.data.transform import (
@@ -30,14 +38,10 @@ def init_dataset():


def test_dataloader_init(): def test_dataloader_init():
dataset = init_dataset() dataset = init_dataset()
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=2, divide=True)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=-1) dataloader = DataLoader(dataset, num_workers=-1)
with pytest.raises(ValueError): with pytest.raises(ValueError):
dataloader = DataLoader(dataset, timeout=-1) dataloader = DataLoader(dataset, timeout=-1)
with pytest.raises(ValueError):
dataloader = DataLoader(dataset, num_workers=0, divide=True)


dataloader = DataLoader(dataset, preload=True) dataloader = DataLoader(dataset, preload=True)
assert isinstance(dataloader.sampler, SequentialSampler) assert isinstance(dataloader.sampler, SequentialSampler)
@@ -59,10 +63,8 @@ def test_dataloader_init():




class MyStream(StreamDataset): class MyStream(StreamDataset):
def __init__(self, number, batch=False, error_foramt=False, block=False):
def __init__(self, number, block=False):
self.number = number self.number = number
self.batch = batch
self.error_format = error_foramt
self.block = block self.block = block


def __iter__(self): def __iter__(self):
@@ -70,22 +72,14 @@ class MyStream(StreamDataset):
if self.block: if self.block:
for _ in range(10): for _ in range(10):
time.sleep(1) time.sleep(1)
if self.batch:
data = np.random.randint(0, 256, (2, 2, 2, 3), dtype="uint8")
yield (True, (data, [cnt, cnt - self.number]))
else:
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
if self.error_format:
yield (data, cnt)
else:
yield (False, (data, cnt))
data = np.random.randint(0, 256, (2, 2, 3), dtype="uint8")
yield (data, cnt)
raise StopIteration raise StopIteration




@pytest.mark.parametrize("batch", [True, False])
@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader(batch, num_workers):
dataset = MyStream(100, batch=batch)
def test_stream_dataloader(num_workers):
dataset = MyStream(100)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
@@ -107,18 +101,9 @@ def test_stream_dataloader(batch, num_workers):
check_set.add(i) check_set.add(i)




def test_stream_dataloader_error():
dataset = MyStream(100, error_foramt=True)
sampler = StreamSampler(batch_size=4)
dataloader = DataLoader(dataset, sampler, preload=True)
with pytest.raises(AssertionError, match=r".*tuple.*"):
data_iter = iter(dataloader)
next(data_iter)


@pytest.mark.parametrize("num_workers", [0, 2]) @pytest.mark.parametrize("num_workers", [0, 2])
def test_stream_dataloader_timeout(num_workers): def test_stream_dataloader_timeout(num_workers):
dataset = MyStream(100, False, block=True)
dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4) sampler = StreamSampler(batch_size=4)


dataloader = DataLoader( dataloader = DataLoader(
@@ -150,18 +135,6 @@ def test_dataloader_parallel():
dataset, dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False), sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2, num_workers=2,
divide=False,
preload=True,
)
for (data, label) in dataloader:
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)

dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=True,
preload=True, preload=True,
) )
for (data, label) in dataloader: for (data, label) in dataloader:
@@ -219,7 +192,7 @@ def test_dataloader_parallel_worker_exception():
num_workers=2, num_workers=2,
preload=True, preload=True,
) )
with pytest.raises(RuntimeError, match=r"worker.*died"):
with pytest.raises(RuntimeError, match=r"exited unexpectedly"):
data_iter = iter(dataloader) data_iter = iter(dataloader)
batch_data = next(data_iter) batch_data = next(data_iter)


@@ -227,28 +200,25 @@ def test_dataloader_parallel_worker_exception():
def _multi_instances_parallel_dataloader_worker(): def _multi_instances_parallel_dataloader_worker():
dataset = init_dataset() dataset = init_dataset()


for divide_flag in [True, False]:
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
divide=divide_flag,
preload=True,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data._tuple_shape == (10, 1, 32, 32)
assert val_label._tuple_shape == (10,)
train_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=4, drop_last=False),
num_workers=2,
preload=True,
)
val_dataloader = DataLoader(
dataset,
sampler=RandomSampler(dataset, batch_size=10, drop_last=False),
num_workers=2,
preload=True,
)
for idx, (data, label) in enumerate(train_dataloader):
assert data._tuple_shape == (4, 1, 32, 32)
assert label._tuple_shape == (4,)
if idx % 5 == 0:
for val_data, val_label in val_dataloader:
assert val_data._tuple_shape == (10, 1, 32, 32)
assert val_label._tuple_shape == (10,)




def test_dataloader_parallel_multi_instances(): def test_dataloader_parallel_multi_instances():
@@ -276,25 +246,3 @@ def test_dataloader_parallel_multi_instances_multiprocessing():
for p in processes: for p in processes:
p.join() p.join()
assert p.exitcode == 0 assert p.exitcode == 0


@pytest.mark.parametrize("num_workers", [0, 2])
def test_timeout_event(num_workers):
def cb():
return (True, (np.zeros(shape=(2, 2, 2, 3)), np.ones(shape=(2,))))

dataset = MyStream(100, block=True)
sampler = StreamSampler(batch_size=4)

dataloader = DataLoader(
dataset,
sampler,
num_workers=num_workers,
timeout=2,
timeout_event=cb,
preload=True,
)
for _, data in enumerate(dataloader):
np.testing.assert_equal(data[0], np.zeros(shape=(4, 2, 2, 3)))
np.testing.assert_equal(data[1], np.ones(shape=(4,)))
break

Loading…
Cancel
Save