
test_batch.py

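"""
Unit tests for fastNLP's Batch iterator: simple batching, padding of list and
numpy fields, numpy-vs-torch output types, and sequential iteration.
"""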
import time
import unittest

import numpy as np
import torch

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.sampler import SequentialSampler


def generate_fake_dataset(num_samples=1000):
    """
    Generate a DataSet with the fields {'0': [], '1': [], '2': [], '3': []}.

    :param num_samples: number of samples
    :return: a DataSet whose fields are randomly marked as input or target
    """
    max_len = 50
    min_len = 10
    num_features = 4

    # Each field holds num_samples integer sequences of random length.
    data_dict = {}
    for i in range(num_features):
        data = []
        lengths = np.random.randint(min_len, max_len, size=(num_samples,))
        for length in lengths:
            data.append(np.random.randint(100, size=length))
        data_dict[str(i)] = data

    dataset = DataSet(data_dict)

    # Randomly mark each field as either an input or a target field.
    for i in range(num_features):
        if np.random.randint(2) == 0:
            dataset.set_input(str(i))
        else:
            dataset.set_target(str(i))
    return dataset
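
# Example: generate_fake_dataset(10) returns a DataSet with 10 instances and four
# variable-length integer fields '0'..'3', each randomly marked as input or target.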


def construct_dataset(sentences):
    """Construct a data set from a list of sentences.

    :param sentences: list of list of str
    :return dataset: a DataSet object
    """
    dataset = DataSet()
    for sentence in sentences:
        instance = Instance()
        instance['raw_sentence'] = sentence
        dataset.append(instance)
    return dataset


class TestCase1(unittest.TestCase):
    def test_simple(self):
        dataset = construct_dataset(
            [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
        dataset.set_target()
        batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)

        # 40 instances at batch_size=4 should yield exactly 10 batches.
        cnt = 0
        for _, _ in batch:
            cnt += 1
        self.assertEqual(cnt, 10)

    def test_dataset_batching(self):
        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
        ds.set_input("x")
        ds.set_target("y")
        batch_iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
        for x, y in batch_iter:
            self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))
            self.assertEqual(len(x["x"]), 4)
            self.assertEqual(len(y["y"]), 4)
            self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4])
            self.assertListEqual(list(y["y"][-1]), [5, 6])

    def test_list_padding(self):
        ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
                      "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
        ds.set_input("x")
        ds.set_target("y")
        batch_iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
        for x, y in batch_iter:
            self.assertEqual(x["x"].shape, (4, 4))
            self.assertEqual(y["y"].shape, (4, 4))
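
    # Note: each batch above is padded to the length of its longest sequence
    # (here 4), presumably with fastNLP's default pad value of 0, so e.g.
    # [1] becomes [1, 0, 0, 0] in the (4, 4) output.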

    def test_numpy_padding(self):
        ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10),
                      "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
        ds.set_input("x")
        ds.set_target("y")
        batch_iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
        for x, y in batch_iter:
            self.assertEqual(x["x"].shape, (4, 4))
            self.assertEqual(y["y"].shape, (4, 4))

    def test_list_to_tensor(self):
        ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
                      "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
        ds.set_input("x")
        ds.set_target("y")
        batch_iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
        for x, y in batch_iter:
            self.assertTrue(isinstance(x["x"], torch.Tensor))
            self.assertEqual(tuple(x["x"].shape), (4, 4))
            self.assertTrue(isinstance(y["y"], torch.Tensor))
            self.assertEqual(tuple(y["y"].shape), (4, 4))

    def test_numpy_to_tensor(self):
        ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10),
                      "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
        ds.set_input("x")
        ds.set_target("y")
        batch_iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
        for x, y in batch_iter:
            self.assertTrue(isinstance(x["x"], torch.Tensor))
            self.assertEqual(tuple(x["x"].shape), (4, 4))
            self.assertTrue(isinstance(y["y"], torch.Tensor))
            self.assertEqual(tuple(y["y"].shape), (4, 4))

    def test_list_of_list_to_tensor(self):
        ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] +
                     [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
        ds.set_input("x")
        ds.set_target("y")
        batch_iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
        for x, y in batch_iter:
            self.assertTrue(isinstance(x["x"], torch.Tensor))
            self.assertEqual(tuple(x["x"].shape), (4, 4))
            self.assertTrue(isinstance(y["y"], torch.Tensor))
            self.assertEqual(tuple(y["y"].shape), (4, 4))

    def test_list_of_numpy_to_tensor(self):
        ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] +
                     [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)])
        ds.set_input("x")
        ds.set_target("y")
        batch_iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
        # Smoke test: only checks that iteration succeeds, without asserting shapes.
        for x, y in batch_iter:
            print(x, y)

    def test_sequential_batch(self):
        batch_size = 32
        pause_seconds = 0.01
        num_samples = 1000
        dataset = generate_fake_dataset(num_samples)

        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
        for batch_x, batch_y in batch:
            time.sleep(pause_seconds)
  131. """
  132. def test_multi_workers_batch(self):
  133. batch_size = 32
  134. pause_seconds = 0.01
  135. num_samples = 1000
  136. dataset = generate_fake_dataset(num_samples)
  137. num_workers = 1
  138. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
  139. for batch_x, batch_y in batch:
  140. time.sleep(pause_seconds)
  141. num_workers = 2
  142. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
  143. end1 = time.time()
  144. for batch_x, batch_y in batch:
  145. time.sleep(pause_seconds)
  146. """
  147. """
  148. def test_pin_memory(self):
  149. batch_size = 32
  150. pause_seconds = 0.01
  151. num_samples = 1000
  152. dataset = generate_fake_dataset(num_samples)
  153. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
  154. # 这里发生OOM
  155. # for batch_x, batch_y in batch:
  156. # time.sleep(pause_seconds)
  157. num_workers = 2
  158. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
  159. pin_memory=True)
  160. # 这里发生OOM
  161. # for batch_x, batch_y in batch:
  162. # time.sleep(pause_seconds)
  163. """