GitOrigin-RevId: 6262561355
tags/v1.2.0
@@ -19,7 +19,7 @@ import numpy as np
 from ..logger import get_logger
 from ..random.rng import _random_seed_generator
 from .collator import Collator
-from .dataset import Dataset, MapDataset, StreamDataset
+from .dataset import Dataset, StreamDataset
 from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
 from .transform import PseudoTransform, Transform
@@ -88,7 +88,15 @@ class DataLoader:
         self.divide = divide

-        if isinstance(dataset, MapDataset):
+        if isinstance(dataset, StreamDataset):
+            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
+            assert isinstance(
+                self.sampler, StreamSampler
+            ), "types of dataset and sampler do not match"
+        else:
+            assert isinstance(
+                dataset, Dataset
+            ), "Can not recognize this kind of dataset: %s" % type(dataset)
             self.sampler = (
                 sampler
                 if sampler
@@ -97,15 +105,6 @@ class DataLoader:
             assert isinstance(
                 self.sampler, MapSampler
             ), "types of dataset and sampler do not match"
-        elif isinstance(dataset, StreamDataset):
-            self.sampler = sampler if sampler else StreamSampler(batch_size=1)
-            assert isinstance(
-                self.sampler, StreamSampler
-            ), "types of dataset and sampler do not match"
-        else:
-            raise TypeError(
-                "can not recognize this kind of dataset: %s" % type(dataset)
-            )

         if divide:
             if self.sampler.batch_size <= self.num_workers:
@@ -140,15 +139,14 @@ class DataLoader:
                 return _SerialStreamDataLoaderIter(self)
             else:
                 return _ParallelStreamDataLoaderIter(self)
-        elif isinstance(self.dataset, MapDataset):
+        else:
+            assert isinstance(
+                self.dataset, Dataset
+            ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
             if not self.num_workers:
                 return _SerialMapDataLoaderIter(self)
             else:
                 return _ParallelMapDataLoaderIter(self)
-        else:
-            raise TypeError(
-                "can not recognize this kind of dataset: %s" % type(self.dataset)
-            )

     def __len__(self):
         return len(self.sampler)
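
Not part of the patch: a minimal usage sketch of the two dispatch paths in the reworked DataLoader check. The class and sampler names are taken from the diff above; it assumes DataLoader and the samplers are importable from megengine.data and megengine.data.sampler, and SquaresStream plus the batch sizes are hypothetical.

import numpy as np

from megengine.data import DataLoader
from megengine.data.dataset import ArrayDataset, StreamDataset
from megengine.data.sampler import SequentialSampler, StreamSampler


class SquaresStream(StreamDataset):
    """Hypothetical stream-style dataset: only __iter__ is required."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield np.float32(i * i)


# Map-style path: any Dataset that is not a StreamDataset takes the else-branch
# and is paired with a MapSampler (SequentialSampler here).
map_ds = ArrayDataset(np.arange(16, dtype="float32"))
map_loader = DataLoader(map_ds, sampler=SequentialSampler(map_ds, batch_size=4))

# Stream path: the if-branch requires (or creates) a StreamSampler.
stream_loader = DataLoader(SquaresStream(16), sampler=StreamSampler(batch_size=4))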
@@ -6,5 +6,5 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .meta_dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
+from .meta_dataset import ArrayDataset, Dataset, StreamDataset
 from .vision import *
@@ -12,17 +12,7 @@ from typing import Tuple
 class Dataset(ABC):
     r"""
-    An abstract class for all Datasets.
-    """
-
-    @abstractmethod
-    def __init__(self):
-        pass
-
-
-class MapDataset(Dataset):
-    r"""
-    An abstract class for map data.
+    An abstract class for all datasets.

     __getitem__ and __len__ method are aditionally needed.
     """
@@ -53,8 +43,14 @@ class StreamDataset(Dataset):
     def __iter__(self):
         pass

+    def __getitem__(self, idx):
+        raise AssertionError("can not get item from StreamDataset by index")
+
+    def __len__(self):
+        raise AssertionError("StreamDataset does not have length")
+

-class ArrayDataset(MapDataset):
+class ArrayDataset(Dataset):
     def __init__(self, *arrays):
         r"""
         ArrayDataset is a dataset for numpy array data, one or more numpy arrays
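
Not part of the patch: with MapDataset folded into Dataset, map-style datasets now subclass Dataset directly and must provide __getitem__ and __len__, while StreamDataset keeps only __iter__ and rejects indexing and length queries. The small subclass below is a hypothetical sketch of the new layout.

import numpy as np

from megengine.data.dataset import Dataset


class MyMapDataset(Dataset):
    """Hypothetical map-style dataset: subclass Dataset, not MapDataset."""

    def __init__(self, size=8):
        self.data = np.random.rand(size, 3).astype("float32")

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


ds = MyMapDataset()
assert len(ds) == 8 and ds[0].shape == (3,)

# StreamDataset instances now fail loudly instead of silently lacking these:
#   stream[0]   -> AssertionError: can not get item from StreamDataset by index
#   len(stream) -> AssertionError: StreamDataset does not have length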
@@ -9,10 +9,10 @@
 import collections.abc
 import os

-from ..meta_dataset import MapDataset
+from ..meta_dataset import Dataset


-class VisionDataset(MapDataset):
+class VisionDataset(Dataset):
     _repr_indent = 4

     def __init__(self, root, *, order=None, supported_order=None):
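
Not part of the patch: VisionDataset now derives from Dataset directly, so custom vision datasets keep the same map-style contract. The subclass below is a hypothetical sketch; the import path is inferred from the relative import in the hunk, and it assumes the base class stores the validated order on self.order and that the "image" / "image_category" keys behave as in the existing vision datasets.

import numpy as np

from megengine.data.dataset.vision.meta_vision import VisionDataset


class RandomImages(VisionDataset):
    """Hypothetical map-style vision dataset built on the new base class."""

    def __init__(self, root, *, order=None):
        super().__init__(root, order=order, supported_order=("image", "image_category"))
        self.images = np.random.randint(0, 256, (10, 32, 32, 3), dtype="uint8")
        self.labels = np.arange(10)

    def __getitem__(self, index):
        # Assumption: return components in the order validated by the base class.
        items = {"image": self.images[index], "image_category": self.labels[index]}
        return tuple(items[k] for k in self.order)

    def __len__(self):
        return len(self.images)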
@@ -12,14 +12,12 @@ import sys
 import numpy as np
 import pytest

-from megengine.data.dataset import ArrayDataset, Dataset, MapDataset, StreamDataset
+from megengine.data.dataset import ArrayDataset, Dataset, StreamDataset


 def test_abstract_cls():
     with pytest.raises(TypeError):
         Dataset()
-    with pytest.raises(TypeError):
-        MapDataset()
     with pytest.raises(TypeError):
         StreamDataset()
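
Not part of the patch: a follow-up test sketch for the new StreamDataset guards added above; _Stream is a hypothetical concrete subclass used only so the abstract class can be instantiated.

import pytest

from megengine.data.dataset import StreamDataset


class _Stream(StreamDataset):
    def __init__(self):
        pass

    def __iter__(self):
        yield 0


def test_stream_dataset_guards():
    stream = _Stream()
    # Random access and len() are rejected with AssertionError by the base class.
    with pytest.raises(AssertionError):
        stream[0]
    with pytest.raises(AssertionError):
        len(stream)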