You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_batch.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import unittest
  2. import numpy as np
  3. import torch
  4. from fastNLP import Batch
  5. from fastNLP import DataSet
  6. from fastNLP import Instance
  7. from fastNLP import SequentialSampler
  8. def generate_fake_dataset(num_samples=1000):
  9. """
  10. 产生的DataSet包含以下的field {'1':[], '2':[], '3': [], '4':[]}
  11. :param num_samples: sample的数量
  12. :return:
  13. """
  14. max_len = 50
  15. min_len = 10
  16. num_features = 4
  17. data_dict = {}
  18. for i in range(num_features):
  19. data = []
  20. lengths = np.random.randint(min_len, max_len, size=(num_samples))
  21. for length in lengths:
  22. data.append(np.random.randint(100, size=length))
  23. data_dict[str(i)] = data
  24. dataset = DataSet(data_dict)
  25. for i in range(num_features):
  26. if np.random.randint(2) == 0:
  27. dataset.set_input(str(i))
  28. else:
  29. dataset.set_target(str(i))
  30. return dataset
  31. def construct_dataset(sentences):
  32. """Construct a data set from a list of sentences.
  33. :param sentences: list of list of str
  34. :return dataset: a DataSet object
  35. """
  36. dataset = DataSet()
  37. for sentence in sentences:
  38. instance = Instance()
  39. instance['raw_sentence'] = sentence
  40. dataset.append(instance)
  41. return dataset
  42. class TestCase1(unittest.TestCase):
  43. def test_simple(self):
  44. dataset = construct_dataset(
  45. [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
  46. dataset.set_target()
  47. batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
  48. cnt = 0
  49. for _, _ in batch:
  50. cnt += 1
  51. self.assertEqual(cnt, 10)
  52. def test_dataset_batching(self):
  53. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  54. ds.set_input("x")
  55. ds.set_target("y")
  56. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
  57. for x, y in iter:
  58. self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))
  59. self.assertEqual(len(x["x"]), 4)
  60. self.assertEqual(len(y["y"]), 4)
  61. self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4])
  62. self.assertListEqual(list(y["y"][-1]), [5, 6])
  63. def test_list_padding(self):
  64. ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
  65. "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
  66. ds.set_input("x")
  67. ds.set_target("y")
  68. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
  69. for x, y in iter:
  70. self.assertEqual(x["x"].shape, (4, 4))
  71. self.assertEqual(y["y"].shape, (4, 4))
  72. def test_numpy_padding(self):
  73. ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10),
  74. "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
  75. ds.set_input("x")
  76. ds.set_target("y")
  77. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
  78. for x, y in iter:
  79. self.assertEqual(x["x"].shape, (4, 4))
  80. self.assertEqual(y["y"].shape, (4, 4))
  81. def test_list_to_tensor(self):
  82. ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
  83. "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
  84. ds.set_input("x")
  85. ds.set_target("y")
  86. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
  87. for x, y in iter:
  88. self.assertTrue(isinstance(x["x"], torch.Tensor))
  89. self.assertEqual(tuple(x["x"].shape), (4, 4))
  90. self.assertTrue(isinstance(y["y"], torch.Tensor))
  91. self.assertEqual(tuple(y["y"].shape), (4, 4))
  92. def test_numpy_to_tensor(self):
  93. ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10),
  94. "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
  95. ds.set_input("x")
  96. ds.set_target("y")
  97. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
  98. for x, y in iter:
  99. self.assertTrue(isinstance(x["x"], torch.Tensor))
  100. self.assertEqual(tuple(x["x"].shape), (4, 4))
  101. self.assertTrue(isinstance(y["y"], torch.Tensor))
  102. self.assertEqual(tuple(y["y"].shape), (4, 4))
  103. def test_list_of_list_to_tensor(self):
  104. ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] +
  105. [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
  106. ds.set_input("x")
  107. ds.set_target("y")
  108. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
  109. for x, y in iter:
  110. self.assertTrue(isinstance(x["x"], torch.Tensor))
  111. self.assertEqual(tuple(x["x"].shape), (4, 4))
  112. self.assertTrue(isinstance(y["y"], torch.Tensor))
  113. self.assertEqual(tuple(y["y"].shape), (4, 4))
  114. def test_list_of_numpy_to_tensor(self):
  115. ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] +
  116. [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)])
  117. ds.set_input("x")
  118. ds.set_target("y")
  119. iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
  120. for x, y in iter:
  121. print(x, y)
  122. def test_sequential_batch(self):
  123. batch_size = 32
  124. num_samples = 1000
  125. dataset = generate_fake_dataset(num_samples)
  126. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
  127. for batch_x, batch_y in batch:
  128. pass
  129. """
  130. def test_multi_workers_batch(self):
  131. batch_size = 32
  132. pause_seconds = 0.01
  133. num_samples = 1000
  134. dataset = generate_fake_dataset(num_samples)
  135. num_workers = 1
  136. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
  137. for batch_x, batch_y in batch:
  138. time.sleep(pause_seconds)
  139. num_workers = 2
  140. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
  141. end1 = time.time()
  142. for batch_x, batch_y in batch:
  143. time.sleep(pause_seconds)
  144. """
  145. """
  146. def test_pin_memory(self):
  147. batch_size = 32
  148. pause_seconds = 0.01
  149. num_samples = 1000
  150. dataset = generate_fake_dataset(num_samples)
  151. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
  152. # 这里发生OOM
  153. # for batch_x, batch_y in batch:
  154. # time.sleep(pause_seconds)
  155. num_workers = 2
  156. batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
  157. pin_memory=True)
  158. # 这里发生OOM
  159. # for batch_x, batch_y in batch:
  160. # time.sleep(pause_seconds)
  161. """