You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dataloader.py 6.6 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. from __future__ import absolute_import
  2. import os
  3. import numpy as np
  4. import multiprocessing as mp
  5. from . import ndarray
  6. from .gpu_ops.Node import Op
# Multi-process loading is not useful for now, since we have no memory-to-CPU bottleneck.
class Dataloader(object):
    """Serve fixed-size batches of an in-memory array through a small buffer queue.

    Construction only stores the (transformed) data. ``init_states`` must be
    called before any batch accessor, because it creates ``arrs``,
    ``arr_map``, ``batch_index`` and the rest of the runtime state those
    accessors read.

    Batches live in ``ndarray`` buffers created on ``ndarray.cpu(0)``:
    ``arrs`` holds the buffers and ``arr_map`` maps a batch key to its slot
    index in ``arrs``. At most ``queue_size`` keys are live at any time.
    """

    def __init__(self, raw_data, batch_size, name='default', func=None, drop_last=True):
        """Store the data and the batching options.

        raw_data: input passed through ``func`` and then converted to a
            float32 NumPy array.
        batch_size: requested batch size; ``init_states`` may lower it.
        name: stringified; DataloaderOp keys its loaders by this value.
        func: optional transform applied to ``raw_data`` before conversion;
            any falsy value means identity.
        drop_last: if True, a trailing partial batch is skipped; otherwise it
            is served as a smaller final batch.
        """
        self.func = func if func else lambda x: x
        self.raw_data = np.array(self.func(raw_data), np.float32)
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.name = str(name)

    def init_states(self, rank=None, nrank=None):
        """Build the batching state and prefetch the first ``queue_size`` batches.

        rank, nrank: when ``rank`` is given, keep only the ``rank``-th of
            ``nrank`` equal contiguous shards of ``raw_data``. Samples beyond
            ``nrank * (len // nrank)`` are dropped. ``nrank`` must then also
            be given.

        NOTE(review): ``raw_data`` is overwritten by the shard, so a second
        call with a rank would re-shard already-sharded data. Confirm this
        runs only once per loader.
        """
        if rank is not None:
            cur_size = self.raw_data.shape[0] // nrank
            start = cur_size * rank
            ending = start + cur_size
            self.raw_data = self.raw_data[start:ending]
        self.samples_num = len(self.raw_data)
        self.queue_size = 3 # if use prefetch, needs 3; if only current batch, needs 2
        # Cap the batch size so at least queue_size full batches exist; the
        # prefetch loop below slices queue_size consecutive full batches.
        self.batch_size = min(int(self.batch_size),
                              self.samples_num // self.queue_size)
        assert self.batch_size > 0, 'Batch size %d invalid.' % self.batch_size
        # Batches per pass: ceil when the partial tail is served, floor when dropped.
        self.batch_num = int(np.ceil(self.samples_num / self.batch_size)) if not self.drop_last else \
            self.samples_num // self.batch_size
        self.shape = tuple([self.batch_size] + list(self.raw_data.shape[1:]))
        # Row order used for slicing; it is never permuted within this class.
        self.seq = np.arange(self.samples_num)
        # Next row of raw_data to copy into a buffer.
        self.index = 0
        self.arrs = []
        # batch key -> slot index in self.arrs
        self.arr_map = {}
        # prefetch to fill up the queue
        for i in range(self.queue_size):
            next_index = self.index + self.batch_size
            self.arrs.append(ndarray.array(
                self.raw_data[self.seq[self.index:next_index]], ctx=ndarray.cpu(0)))
            self.index = next_index
            self.arr_map[i] = i
        # Newest key currently in the queue.
        self.max_key = self.queue_size - 1
        # in case the last batch's shape is different, pre-allocate an array
        # (it sits at slot index queue_size, i.e. self.arrs[-1])
        if not self.drop_last:
            res_num = self.samples_num % self.batch_size
            if res_num > 0:
                self.arrs.append(ndarray.empty(
                    tuple([res_num] + list(self.shape[1:])), ctx=ndarray.cpu(0)))
            # Slot to reuse when the tail buffer's key is evicted; while the
            # tail buffer is live, _get_arr parks a regular slot index here.
            self.rest = self.queue_size
        self.batch_index = 0

    def _get_arr(self, batchind):
        """Return the buffer holding batch key ``batchind``.

        Fetching the newest key (``max_key``) advances the queue: the oldest
        key is evicted and its slot is refilled in place with the next batch
        of ``raw_data``. A returned buffer therefore stays valid only until
        its slot is recycled. Keys wrap modulo ``samples_num``, the same
        modulus ``get_arr`` uses for ``batch_index``.
        """
        # get specific batch
        # if the batch to be fetched is the newest one, replace the oldest with new batch
        assert batchind in self.arr_map
        res = self.arrs[self.arr_map[batchind]]
        if batchind == self.max_key:
            self.max_key = (self.max_key + 1) % self.samples_num
            min_key = (self.max_key - self.queue_size) % self.samples_num
            # Wrap to the start of the data: past the end, or (drop_last) when
            # a full batch no longer fits.
            if self.index >= self.samples_num or (self.drop_last and self.index + self.batch_size > self.samples_num):
                self.index = 0
            next_index = self.index + self.batch_size
            if next_index <= self.samples_num:
                # Full batch: reuse the evicted slot. If that slot is the tail
                # buffer, use the regular slot parked in self.rest instead.
                temp_ind = self.arr_map.pop(min_key)
                if temp_ind == self.queue_size and not self.drop_last:
                    temp_ind = self.rest
                    self.rest = self.queue_size
                self.arr_map[self.max_key] = temp_ind
                self.arrs[temp_ind][:] = self.raw_data[self.seq[self.index:next_index]]
            else:
                # Partial tail batch: copy it into the dedicated tail buffer and
                # park the evicted regular slot in self.rest until the tail
                # buffer's key is itself evicted.
                assert not self.drop_last
                self.arrs[-1][:] = self.raw_data[self.seq[self.index:next_index]]
                self.rest = self.arr_map.pop(min_key)
                self.arr_map[self.max_key] = self.queue_size
            self.index = next_index
        return res

    def get_arr(self):
        """Return the batch for ``batch_index``, then advance ``batch_index``.

        The returned batch's row count is recorded in ``last_batch_size``.
        """
        # step forward in this function
        res = self._get_arr(self.batch_index)
        self.last_batch_size = res.shape[0]
        self.batch_index = (self.batch_index + 1) % self.samples_num
        return res

    def get_next_arr(self):
        """Return the batch for the current ``batch_index`` without advancing it.

        This may still trigger a queue refill inside ``_get_arr``.
        """
        res = self._get_arr(self.batch_index)
        return res

    def get_cur_shape(self):
        """Return the shape of the buffer mapped to the current ``batch_index``."""
        return tuple(self.arrs[self.arr_map[self.batch_index]].shape)
  85. class GNNDataLoaderOp(Op):
  86. graph = None
  87. nxt_graph = None
  88. def __init__(self, handler, ctx=ndarray.cpu(0)):
  89. super().__init__(DataloaderOp, [], ctx)
  90. self.on_gpu = True
  91. self.on_cpu = False
  92. self.handler = handler
  93. self.name = "GNNDataloaderOp"
  94. self.desc = self.name
  95. def get_batch_num(self, name):
  96. return None
  97. def get_arr(self, name):
  98. return self.handler(self.graph)
  99. def get_next_arr(self, name):
  100. return self.handler(self.nxt_graph)
  101. def get_cur_shape(self, name):
  102. return self.handler(self.graph).shape
  103. def gradient(self, output_grad):
  104. return None
  105. def infer_shape(self, input_shapes):
  106. raise NotImplementedError
  107. @classmethod
  108. def step(cls, graph):
  109. cls.graph = cls.nxt_graph
  110. cls.nxt_graph = graph
  111. class DataloaderOp(Op):
  112. def __init__(self, dataloaders):
  113. super().__init__(DataloaderOp, [], ndarray.cpu(0))
  114. self.on_gpu = False
  115. self.on_cpu = True
  116. self.dataloaders = {
  117. dl.name: dl for dl in dataloaders
  118. }
  119. self.name = "DataloaderOp%d(%s)" % (
  120. self.id, '_'.join(self.dataloaders.keys()))
  121. self.desc = self.name
  122. def get_batch_num(self, name):
  123. return self.dataloaders[name].batch_num
  124. def get_arr(self, name):
  125. return self.dataloaders[name].get_arr()
  126. def get_next_arr(self, name):
  127. return self.dataloaders[name].get_next_arr()
  128. def get_cur_shape(self, name):
  129. return self.dataloaders[name].get_cur_shape()
  130. def gradient(self, output_grad):
  131. return None
  132. def infer_shape(self, input_shapes):
  133. # actually this function can never be called
  134. raise NotImplementedError
  135. def forward_hook(self, config):
  136. pass
  137. def backward_hook(self, config):
  138. for d in self.dataloaders.values():
  139. if config.context_launch:
  140. d.init_states(config.rank, config.nrank)
  141. else:
  142. d.init_states()
  143. def dataloader_op(dataloaders):
  144. '''
  145. dataloaders: list of dataloaders
  146. '''
  147. temp_dataloaders = []
  148. for dl in dataloaders:
  149. if isinstance(dl, Dataloader):
  150. temp_dataloaders.append(dl)
  151. elif isinstance(dl, list):
  152. temp_dataloaders.append(Dataloader(*dl))
  153. elif isinstance(dl, dict):
  154. temp_dataloaders.append(Dataloader(**dl))
  155. else:
  156. assert False, 'Dataloader parameter invalid.'
  157. return DataloaderOp(temp_dataloaders)