You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 9.2 kB

7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import os
  2. import unittest
  3. from fastNLP.core.dataset import DataSet
  4. from fastNLP.core.fieldarray import FieldArray
  5. from fastNLP.core.instance import Instance
  6. class TestDataSetInit(unittest.TestCase):
  7. """初始化DataSet的办法有以下几种:
  8. 1) 用dict:
  9. 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
  10. 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])})
  11. 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
  12. 2) 用list of Instance:
  13. 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])])
  14. 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))])
  15. 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
  16. 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
  17. 只接受纯list或者最外层ndarray
  18. """
  19. def test_init_v1(self):
  20. # 一维list
  21. ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
  22. self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
  23. self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
  24. self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
  25. def test_init_v2(self):
  26. # 用dict
  27. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  28. self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
  29. self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
  30. self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
  31. def test_init_assert(self):
  32. with self.assertRaises(AssertionError):
  33. _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
  34. with self.assertRaises(AssertionError):
  35. _ = DataSet([[1, 2, 3, 4]] * 10)
  36. with self.assertRaises(ValueError):
  37. _ = DataSet(0.00001)
  38. class TestDataSetMethods(unittest.TestCase):
  39. def test_append(self):
  40. dd = DataSet()
  41. for _ in range(3):
  42. dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
  43. self.assertEqual(len(dd), 3)
  44. self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
  45. self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
  46. def test_add_field(self):
  47. dd = DataSet()
  48. dd.add_field("x", [[1, 2, 3]] * 10)
  49. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  50. dd.add_field("z", [[5, 6]] * 10)
  51. self.assertEqual(len(dd), 10)
  52. self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
  53. self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
  54. self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
  55. with self.assertRaises(RuntimeError):
  56. dd.add_field("??", [[1, 2]] * 40)
  57. def test_add_field_ignore_type(self):
  58. dd = DataSet()
  59. dd.add_field("x", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], ignore_type=True, is_target=True)
  60. dd.add_field("y", [{1, "1"}, {2, "2"}, {3, "3"}, {4, "4"}], ignore_type=True, is_target=True)
  61. def test_delete_field(self):
  62. dd = DataSet()
  63. dd.add_field("x", [[1, 2, 3]] * 10)
  64. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  65. dd.delete_field("x")
  66. self.assertFalse("x" in dd.field_arrays)
  67. self.assertTrue("y" in dd.field_arrays)
  68. def test_getitem(self):
  69. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  70. ins_1, ins_0 = ds[0], ds[1]
  71. self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
  72. self.assertEqual(ins_1["x"], [1, 2, 3, 4])
  73. self.assertEqual(ins_1["y"], [5, 6])
  74. self.assertEqual(ins_0["x"], [1, 2, 3, 4])
  75. self.assertEqual(ins_0["y"], [5, 6])
  76. sub_ds = ds[:10]
  77. self.assertTrue(isinstance(sub_ds, DataSet))
  78. self.assertEqual(len(sub_ds), 10)
  79. def test_get_item_error(self):
  80. with self.assertRaises(RuntimeError):
  81. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  82. _ = ds[40:]
  83. with self.assertRaises(KeyError):
  84. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  85. _ = ds["kom"]
  86. def test_len_(self):
  87. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  88. self.assertEqual(len(ds), 40)
  89. ds = DataSet()
  90. self.assertEqual(len(ds), 0)
  91. def test_apply(self):
  92. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  93. ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
  94. self.assertTrue("rx" in ds.field_arrays)
  95. self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
  96. ds.apply(lambda ins: len(ins["y"]), new_field_name="y")
  97. self.assertEqual(ds.field_arrays["y"].content[0], 2)
  98. res = ds.apply(lambda ins: len(ins["x"]))
  99. self.assertTrue(isinstance(res, list) and len(res) > 0)
  100. self.assertTrue(res[0], 4)
  101. ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True)
  102. # expect no exception raised
  103. def test_drop(self):
  104. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
  105. ds.drop(lambda ins: len(ins["y"]) < 3)
  106. self.assertEqual(len(ds), 20)
  107. def test_contains(self):
  108. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  109. self.assertTrue("x" in ds)
  110. self.assertTrue("y" in ds)
  111. self.assertFalse("z" in ds)
  112. def test_rename_field(self):
  113. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  114. ds.rename_field("x", "xx")
  115. self.assertTrue("xx" in ds)
  116. self.assertFalse("x" in ds)
  117. with self.assertRaises(KeyError):
  118. ds.rename_field("yyy", "oo")
  119. def test_input_target(self):
  120. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  121. ds.set_input("x")
  122. ds.set_target("y")
  123. self.assertTrue(ds.field_arrays["x"].is_input)
  124. self.assertTrue(ds.field_arrays["y"].is_target)
  125. with self.assertRaises(KeyError):
  126. ds.set_input("xxx")
  127. with self.assertRaises(KeyError):
  128. ds.set_input("yyy")
  129. def test_get_input_name(self):
  130. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  131. self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input])
  132. def test_get_target_name(self):
  133. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  134. self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])
  135. def test_apply2(self):
  136. def split_sent(ins):
  137. return ins['raw_sentence'].split()
  138. dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'),
  139. sep='\t')
  140. dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
  141. dataset.apply(split_sent, new_field_name='words', is_input=True)
  142. # print(dataset)
  143. def test_add_field_v2(self):
  144. ds = DataSet({"x": [3, 4]})
  145. ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
  146. # ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y')
  147. print(ds)
  148. def test_save_load(self):
  149. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  150. ds.save("./my_ds.pkl")
  151. self.assertTrue(os.path.exists("./my_ds.pkl"))
  152. ds_1 = DataSet.load("./my_ds.pkl")
  153. os.remove("my_ds.pkl")
  154. def test_get_all_fields(self):
  155. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  156. ans = ds.get_all_fields()
  157. self.assertEqual(ans["x"].content, [[1, 2, 3, 4]] * 10)
  158. self.assertEqual(ans["y"].content, [[5, 6]] * 10)
  159. def test_get_field(self):
  160. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  161. ans = ds.get_field("x")
  162. self.assertTrue(isinstance(ans, FieldArray))
  163. self.assertEqual(ans.content, [[1, 2, 3, 4]] * 10)
  164. ans = ds.get_field("y")
  165. self.assertTrue(isinstance(ans, FieldArray))
  166. self.assertEqual(ans.content, [[5, 6]] * 10)
  167. def test_reader(self):
  168. # 跑通即可
  169. ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv")
  170. self.assertTrue(isinstance(ds, DataSet))
  171. self.assertTrue(len(ds) > 0)
  172. ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt")
  173. self.assertTrue(isinstance(ds, DataSet))
  174. self.assertTrue(len(ds) > 0)
  175. ds = DataSet().read_pos("test/data_for_tests/people.txt")
  176. self.assertTrue(isinstance(ds, DataSet))
  177. self.assertTrue(len(ds) > 0)
  178. def test_add_null(self):
  179. # TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError'
  180. ds = DataSet()
  181. ds.add_field('test', [])
  182. ds.set_target('test')
  183. class TestDataSetIter(unittest.TestCase):
  184. def test__repr__(self):
  185. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  186. for iter in ds:
  187. self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4] type=list,\n'y': [5, 6] type=list}")