You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

dataloader.py 2.6 kB

7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
7 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import pickle
  2. import random
  3. import torch
  4. from torch.autograd import Variable
  5. def float_wrapper(x, requires_grad=True, using_cuda=True):
  6. """
  7. transform float type list to pytorch variable
  8. """
  9. if using_cuda==True:
  10. return Variable(torch.FloatTensor(x).cuda(), requires_grad=requires_grad)
  11. else:
  12. return Variable(torch.FloatTensor(x), requires_grad=requires_grad)
  13. def long_wrapper(x, requires_grad=True, using_cuda=True):
  14. """
  15. transform long type list to pytorch variable
  16. """
  17. if using_cuda==True:
  18. return Variable(torch.LongTensor(x).cuda(), requires_grad=requires_grad)
  19. else:
  20. return Variable(torch.LongTensor(x), requires_grad=requires_grad)
  21. def pad(X, using_cuda):
  22. """
  23. zero-pad sequnces to same length then pack them together
  24. """
  25. maxlen = max([x.size(0) for x in X])
  26. Y = []
  27. for x in X:
  28. padlen = maxlen - x.size(0)
  29. if padlen > 0:
  30. if using_cuda:
  31. paddings = Variable(torch.zeros(padlen).long()).cuda()
  32. else:
  33. paddings = Variable(torch.zeros(padlen).long())
  34. x_ = torch.cat((x, paddings), 0)
  35. Y.append(x_)
  36. else:
  37. Y.append(x)
  38. return torch.stack(Y)
  39. class DataLoader(object):
  40. """
  41. load data with form {"feature", "class"}
  42. Args:
  43. fdir : data file address
  44. batch_size : batch_size
  45. shuffle : if True, shuffle dataset every epoch
  46. using_cuda : if True, return tensors on GPU
  47. """
  48. def __init__(self, fdir, batch_size, shuffle=True, using_cuda=True):
  49. with open(fdir, "rb") as f:
  50. self.data = pickle.load(f)
  51. self.batch_size = batch_size
  52. self.num = len(self.data)
  53. self.count = 0
  54. self.iters = int(self.num / batch_size)
  55. self.shuffle = shuffle
  56. self.using_cuda = using_cuda
  57. def __iter__(self):
  58. return self
  59. def __next__(self):
  60. if self.count == self.iters:
  61. self.count = 0
  62. if self.shuffle:
  63. random.shuffle(self.data)
  64. raise StopIteration()
  65. else:
  66. batch = self.data[self.count * self.batch_size : (self.count + 1) * self.batch_size]
  67. self.count += 1
  68. X = [long_wrapper(x["sent"], using_cuda=self.using_cuda, requires_grad=False) for x in batch]
  69. X = pad(X, self.using_cuda)
  70. y = long_wrapper([x["class"] for x in batch], using_cuda=self.using_cuda, requires_grad=False)
  71. return {"feature" : X, "class" : y}