#! /usr/bin/python
# -*- coding: utf-8 -*-
"""Dataset and data-loader helpers built on ``paddle.io``."""

import numpy as np
import paddle
from paddle.io import Dataset, BatchSampler, DataLoader, IterableDataset

__all__ = [
    'Concat',
    'FromGenerator',
    'FromSlices',
    'Map',
    # 'Shuffle',
    # 'Batch',
    'Dataloader',
]


def to_list(value):
    """Normalize ``value`` to a list.

    ``None`` is returned unchanged, a list or tuple is copied into a new
    list, and any other object is wrapped in a single-element list.
    """
    if value is None:
        return value
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


class FromGenerator(Dataset):
    """Map-style dataset materialized from a generator function.

    Args:
        generator: A zero-argument callable returning an iterable of
            ``(data, label)`` pairs. It is called once and fully consumed
            in ``__init__``, so every sample is held in memory.

    Raises:
        TypeError: If ``generator`` is not callable.
    """

    def __init__(self, generator):
        if not callable(generator):
            raise TypeError("'generator' must be callable")
        self.generator = generator()
        self.datas = []
        self.labels = []
        for data, label in self.generator:
            self.datas.append(data)
            self.labels.append(label)

    def __getitem__(self, idx):
        return self.datas[idx], self.labels[idx]

    def __len__(self):
        # ``datas`` is a plain Python list built in __init__, so it has no
        # ``.shape``; use len() instead.
        return len(self.datas)


class FromSlices(Dataset):
    """Map-style dataset over parallel data / label sequences.

    Args:
        datas: A pair ``(data, labels)`` of indexable sequences of equal
            length; element ``i`` of each forms sample ``i``.
        transform: Optional callable applied to each data tensor.

    Raises:
        ValueError: If data and labels differ in length.
    """

    def __init__(self, datas, transform=None):
        self.datas = datas[0]
        self.labels = datas[1]
        self.transform = transform
        if len(self.datas) != len(self.labels):
            raise ValueError('Datas and labels not have same shape of the 1st dimension.')

    def __getitem__(self, idx):
        # Each item is converted on access: float32 data, int64 label.
        data = paddle.to_tensor(self.datas[idx], dtype='float32')
        label = paddle.to_tensor(self.labels[idx], dtype='int64')
        if self.transform is not None:
            data = self.transform(data)
        return data, label

    def __len__(self):
        return len(self.datas)


class Concat(IterableDataset):
    """Iterable dataset that yields every sample of each input dataset in order.

    Args:
        datasets: A non-empty iterable of ``paddle.io.IterableDataset``.

    NOTE: the input checks use ``assert`` and are skipped under ``python -O``.
    """

    def __init__(self, datasets):
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "input datasets should not be empty"
        for dataset in self.datasets:
            assert isinstance(dataset, IterableDataset), \
                "Concat only supports paddle.io.IterableDataset"

    def __iter__(self):
        for dataset in self.datasets:
            yield from dataset


class Map(Dataset):
    """Map-style dataset applying ``transform`` to the data part of each sample.

    Args:
        dataset: Either a ``paddle.io.Dataset`` whose items are indexable
            ``(data, label)`` pairs, or a ``[data, label]`` list/tuple of
            parallel sequences.
        transform: Callable applied to each data item after it has been
            converted to a NumPy array. Labels are returned unchanged.

    Raises:
        TypeError: If ``dataset`` is neither a Dataset nor a list/tuple.
    """

    def __init__(self, dataset, transform):
        self.isDataset = False
        self.transform = transform
        if isinstance(dataset, Dataset):
            self.isDataset = True
            self.dataset = dataset
        elif isinstance(dataset, (list, tuple)):
            self.datas = dataset[0]
            self.labels = dataset[1]
        else:
            raise TypeError(
                " 'dataset' should be subclass instance of paddle.io.Dataset "
                "or a [data, label] list/tuple, not a {}".format(type(dataset))
            )

    def __getitem__(self, idx):
        if self.isDataset:
            # Fetch the sample once: indexing twice would repeat any work
            # (or randomness) inside the wrapped dataset's __getitem__.
            sample = self.dataset[idx]
            x, y = sample[0], sample[1]
        else:
            x = self.datas[idx]
            y = self.labels[idx]
        if not isinstance(x, np.ndarray):
            x = np.asarray(x)
        return self.transform(x), y

    def __len__(self):
        if self.isDataset:
            # Number of samples in the wrapped dataset; len(self.dataset[0])
            # would instead be the length of the first sample.
            return len(self.dataset)
        return len(self.datas)


def Dataloader(dataset, batch_size=None, shuffle=False, drop_last=False, prefetch=0, shuffle_buffer_size=0):
    """Build a ``paddle.io.DataLoader`` over ``dataset``.

    Args:
        dataset: The dataset to load from.
        batch_size: Samples per batch; forwarded to ``paddle.io.DataLoader``.
        shuffle: Whether to shuffle; forwarded to ``paddle.io.DataLoader``.
        drop_last: Whether to drop the final incomplete batch; forwarded.
        prefetch: Accepted but ignored by this implementation.
        shuffle_buffer_size: Accepted but ignored by this implementation.

    Returns:
        A ``paddle.io.DataLoader``.
    """
    # NOTE(review): prefetch / shuffle_buffer_size have no effect here --
    # presumably kept for signature parity with other backends; confirm.
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)