from __future__ import absolute_import
import os
import numpy as np
import multiprocessing as mp
from . import ndarray
from .gpu_ops.Node import Op

# Multi-Process not useful now, since we don't have memory to CPU bottleneck


class Dataloader(object):
    """Mini-batch provider over an in-memory numpy dataset.

    Keeps a small ring of pre-filled CPU arrays (`queue_size` slots) so the
    current batch and the next batches are always materialized.  Batches are
    addressed by a monotonically increasing key that wraps modulo
    `samples_num`; `arr_map` maps each live key to a slot index in `arrs`.

    Note: `init_states` must be called (directly or via
    `DataloaderOp.backward_hook`) before any `get_*` method is used.
    """

    def __init__(self, raw_data, batch_size, name='default', func=None, drop_last=True):
        # func: optional preprocessing applied once to the whole dataset;
        # identity when not given.
        self.func = func if func else lambda x: x
        # Entire dataset is converted to a float32 numpy array up front.
        self.raw_data = np.array(self.func(raw_data), np.float32)
        self.batch_size = batch_size
        # drop_last: when True, a trailing partial batch is discarded;
        # when False, a spare buffer is allocated for it (see init_states).
        self.drop_last = drop_last
        self.name = str(name)

    def init_states(self, rank=None, nrank=None):
        """(Re)build all iteration state and prefetch the first batches.

        rank/nrank: when given, this worker keeps only its contiguous
        1/nrank shard of the data (equal-sized shards; any remainder rows
        beyond cur_size * nrank are dropped for every rank).
        """
        if rank is not None:
            cur_size = self.raw_data.shape[0] // nrank
            start = cur_size * rank
            ending = start + cur_size
            self.raw_data = self.raw_data[start:ending]
        self.samples_num = len(self.raw_data)
        self.queue_size = 3  # if use prefetch, needs 3; if only current batch, needs 2
        # Shrink batch_size if the (sharded) dataset cannot fill the queue.
        self.batch_size = min(int(self.batch_size), self.samples_num // self.queue_size)
        assert self.batch_size > 0, 'Batch size %d invalid.' % self.batch_size
        # With drop_last the partial tail batch is not counted.
        self.batch_num = int(np.ceil(self.samples_num / self.batch_size)) if not self.drop_last else \
            self.samples_num // self.batch_size
        self.shape = tuple([self.batch_size] + list(self.raw_data.shape[1:]))
        # seq: sample-index permutation (currently the identity order).
        self.seq = np.arange(self.samples_num)
        # index: cursor into seq marking the next sample to load.
        self.index = 0
        self.arrs = []     # slot buffers; slots 0..queue_size-1, plus an
                           # optional undersized slot at index queue_size
        self.arr_map = {}  # live batch key -> slot index in self.arrs
        # prefetch to fill up the queue
        for i in range(self.queue_size):
            next_index = self.index + self.batch_size
            # ndarray.array presumably copies the slice into a CPU-backed
            # buffer (project-local ndarray module — not visible here).
            self.arrs.append(ndarray.array(
                self.raw_data[self.seq[self.index:next_index]], ctx=ndarray.cpu(0)))
            self.index = next_index
            self.arr_map[i] = i
        # max_key: newest batch key currently resident in the queue.
        self.max_key = self.queue_size - 1
        # in case the last batch's shape is different, pre-allocate an array
        if not self.drop_last:
            res_num = self.samples_num % self.batch_size
            if res_num > 0:
                self.arrs.append(ndarray.empty(
                    tuple([res_num] + list(self.shape[1:])), ctx=ndarray.cpu(0)))
        # rest: slot index currently parked outside arr_map; starts as the
        # spare (undersized) slot, index queue_size.
        self.rest = self.queue_size
        # batch_index: key of the batch the consumer will read next.
        self.batch_index = 0

    def _get_arr(self, batchind):
        # get specific batch
        # if the batch to be fetched is the newest one, replace the oldest with new batch
        assert batchind in self.arr_map
        res = self.arrs[self.arr_map[batchind]]
        if batchind == self.max_key:
            # Advance the window: retire the oldest key (min_key) and load
            # the next batch into its slot.  Keys wrap modulo samples_num,
            # consistently with batch_index in get_arr.
            self.max_key = (self.max_key + 1) % self.samples_num
            min_key = (self.max_key - self.queue_size) % self.samples_num
            # Wrap the data cursor at epoch end (with drop_last, also when
            # only a partial batch remains).
            if self.index >= self.samples_num or (self.drop_last and self.index + self.batch_size > self.samples_num):
                self.index = 0
            next_index = self.index + self.batch_size
            if next_index <= self.samples_num:
                # Full-size batch: reuse the retired slot.
                temp_ind = self.arr_map.pop(min_key)
                if temp_ind == self.queue_size and not self.drop_last:
                    # Retired slot is the undersized spare; swap back to the
                    # full-size slot parked in self.rest.
                    temp_ind = self.rest
                    self.rest = self.queue_size
                self.arr_map[self.max_key] = temp_ind
                self.arrs[temp_ind][:] = self.raw_data[self.seq[self.index:next_index]]
            else:
                # Partial tail batch: only reachable when drop_last is False;
                # use the pre-allocated undersized slot (last element of
                # arrs) and park the retired full-size slot in self.rest.
                assert not self.drop_last
                self.arrs[-1][:] = self.raw_data[self.seq[self.index:next_index]]
                self.rest = self.arr_map.pop(min_key)
                self.arr_map[self.max_key] = self.queue_size
            self.index = next_index
        return res

    def get_arr(self):
        # step forward in this function
        res = self._get_arr(self.batch_index)
        self.last_batch_size = res.shape[0]
        self.batch_index = (self.batch_index + 1) % self.samples_num
        return res

    def get_next_arr(self):
        # Peek at the upcoming batch without advancing batch_index (may
        # still trigger a prefetch inside _get_arr).
        res = self._get_arr(self.batch_index)
        return res

    def get_cur_shape(self):
        # Shape of the batch that get_arr would return next.
        return tuple(self.arrs[self.arr_map[self.batch_index]].shape)


class GNNDataLoaderOp(Op):
    """Graph-input variant of the dataloader op.

    Batches are produced by calling `handler` on a class-level graph object
    that the training loop swaps via `step` (double-buffered: `graph` is
    current, `nxt_graph` is the prefetched one).
    """

    # Class-level double buffer shared by all instances.
    graph = None
    nxt_graph = None

    def __init__(self, handler, ctx=ndarray.cpu(0)):
        # Registers under DataloaderOp as the op type — presumably so the
        # executor treats both loader ops uniformly; confirm against the Op
        # base class (not visible here).
        super().__init__(DataloaderOp, [], ctx)
        self.on_gpu = True
        self.on_cpu = False
        # handler: callable mapping a graph object to batch data.
        self.handler = handler
        self.name = "GNNDataloaderOp"
        self.desc = self.name

    def get_batch_num(self, name):
        # Unknown ahead of time for streamed graphs.
        return None

    def get_arr(self, name):
        return self.handler(self.graph)

    def get_next_arr(self, name):
        return self.handler(self.nxt_graph)

    def get_cur_shape(self, name):
        return self.handler(self.graph).shape

    def gradient(self, output_grad):
        # Data source: no gradient flows back.
        return None

    def infer_shape(self, input_shapes):
        raise NotImplementedError

    @classmethod
    def step(cls, graph):
        # Rotate the double buffer: prefetched graph becomes current,
        # the newly supplied graph becomes the prefetched one.
        cls.graph = cls.nxt_graph
        cls.nxt_graph = graph


class DataloaderOp(Op):
    """Computation-graph op wrapping one or more named `Dataloader`s."""

    def __init__(self, dataloaders):
        super().__init__(DataloaderOp, [], ndarray.cpu(0))
        self.on_gpu = False
        self.on_cpu = True
        # Keyed by each dataloader's name; duplicate names would silently
        # overwrite earlier entries.
        self.dataloaders = {
            dl.name: dl for dl in dataloaders
        }
        self.name = "DataloaderOp%d(%s)" % (
            self.id, '_'.join(self.dataloaders.keys()))
        self.desc = self.name

    def get_batch_num(self, name):
        return self.dataloaders[name].batch_num

    def get_arr(self, name):
        return self.dataloaders[name].get_arr()

    def get_next_arr(self, name):
        return self.dataloaders[name].get_next_arr()

    def get_cur_shape(self, name):
        return self.dataloaders[name].get_cur_shape()

    def gradient(self, output_grad):
        # Data source: no gradient flows back.
        return None

    def infer_shape(self, input_shapes):
        # actually this function can never be called
        raise NotImplementedError

    def forward_hook(self, config):
        pass

    def backward_hook(self, config):
        # Initialize loader state at graph-build time; with a distributed
        # context launch, each worker keeps only its shard.
        for d in self.dataloaders.values():
            if config.context_launch:
                d.init_states(config.rank, config.nrank)
            else:
                d.init_states()


def dataloader_op(dataloaders):
    '''
    dataloaders: list of dataloaders

    Each entry may be a ready-made Dataloader, a list of positional
    Dataloader-constructor arguments, or a dict of keyword arguments.
    Returns a DataloaderOp wrapping all of them.
    '''
    temp_dataloaders = []
    for dl in dataloaders:
        if isinstance(dl, Dataloader):
            temp_dataloaders.append(dl)
        elif isinstance(dl, list):
            temp_dataloaders.append(Dataloader(*dl))
        elif isinstance(dl, dict):
            temp_dataloaders.append(Dataloader(**dl))
        else:
            assert False, 'Dataloader parameter invalid.'
    return DataloaderOp(temp_dataloaders)