You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 6.6 kB

7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
import os
import tempfile
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
  5. class TestDataSet(unittest.TestCase):
  6. def test_init_v1(self):
  7. ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
  8. self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
  9. self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
  10. self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
  11. def test_init_v2(self):
  12. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  13. self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
  14. self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
  15. self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
  16. def test_init_assert(self):
  17. with self.assertRaises(AssertionError):
  18. _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
  19. with self.assertRaises(AssertionError):
  20. _ = DataSet([[1, 2, 3, 4]] * 10)
  21. with self.assertRaises(ValueError):
  22. _ = DataSet(0.00001)
  23. def test_append(self):
  24. dd = DataSet()
  25. for _ in range(3):
  26. dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
  27. self.assertEqual(len(dd), 3)
  28. self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
  29. self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
  30. def test_add_append(self):
  31. dd = DataSet()
  32. dd.add_field("x", [[1, 2, 3]] * 10)
  33. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  34. dd.add_field("z", [[5, 6]] * 10)
  35. self.assertEqual(len(dd), 10)
  36. self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
  37. self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
  38. self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
  39. with self.assertRaises(RuntimeError):
  40. dd.add_field("??", [[1, 2]] * 40)
  41. def test_delete_field(self):
  42. dd = DataSet()
  43. dd.add_field("x", [[1, 2, 3]] * 10)
  44. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  45. dd.delete_field("x")
  46. self.assertFalse("x" in dd.field_arrays)
  47. self.assertTrue("y" in dd.field_arrays)
  48. def test_getitem(self):
  49. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  50. ins_1, ins_0 = ds[0], ds[1]
  51. self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
  52. self.assertEqual(ins_1["x"], [1, 2, 3, 4])
  53. self.assertEqual(ins_1["y"], [5, 6])
  54. self.assertEqual(ins_0["x"], [1, 2, 3, 4])
  55. self.assertEqual(ins_0["y"], [5, 6])
  56. sub_ds = ds[:10]
  57. self.assertTrue(isinstance(sub_ds, DataSet))
  58. self.assertEqual(len(sub_ds), 10)
  59. def test_get_item_error(self):
  60. with self.assertRaises(RuntimeError):
  61. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  62. _ = ds[40:]
  63. with self.assertRaises(KeyError):
  64. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  65. _ = ds["kom"]
  66. def test_len_(self):
  67. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  68. self.assertEqual(len(ds), 40)
  69. ds = DataSet()
  70. self.assertEqual(len(ds), 0)
  71. def test_apply(self):
  72. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  73. ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
  74. self.assertTrue("rx" in ds.field_arrays)
  75. self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
  76. ds.apply(lambda ins: len(ins["y"]), new_field_name="y")
  77. self.assertEqual(ds.field_arrays["y"].content[0], 2)
  78. res = ds.apply(lambda ins: len(ins["x"]))
  79. self.assertTrue(isinstance(res, list) and len(res) > 0)
  80. self.assertTrue(res[0], 4)
  81. def test_drop(self):
  82. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
  83. ds.drop(lambda ins: len(ins["y"]) < 3)
  84. self.assertEqual(len(ds), 20)
  85. def test_contains(self):
  86. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  87. self.assertTrue("x" in ds)
  88. self.assertTrue("y" in ds)
  89. self.assertFalse("z" in ds)
  90. def test_rename_field(self):
  91. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  92. ds.rename_field("x", "xx")
  93. self.assertTrue("xx" in ds)
  94. self.assertFalse("x" in ds)
  95. with self.assertRaises(KeyError):
  96. ds.rename_field("yyy", "oo")
  97. def test_input_target(self):
  98. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  99. ds.set_input("x")
  100. ds.set_target("y")
  101. self.assertTrue(ds.field_arrays["x"].is_input)
  102. self.assertTrue(ds.field_arrays["y"].is_target)
  103. with self.assertRaises(KeyError):
  104. ds.set_input("xxx")
  105. with self.assertRaises(KeyError):
  106. ds.set_input("yyy")
  107. def test_get_input_name(self):
  108. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  109. self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input])
  110. def test_get_target_name(self):
  111. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  112. self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])
  113. def test_apply2(self):
  114. def split_sent(ins):
  115. return ins['raw_sentence'].split()
  116. dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'),
  117. sep='\t')
  118. dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
  119. dataset.apply(split_sent, new_field_name='words', is_input=True)
  120. # print(dataset)
  121. def test_add_field(self):
  122. ds = DataSet({"x": [3, 4]})
  123. ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
  124. # ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y')
  125. print(ds)
  126. def test_save_load(self):
  127. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  128. ds.save("./my_ds.pkl")
  129. self.assertTrue(os.path.exists("./my_ds.pkl"))
  130. ds_1 = DataSet.load("./my_ds.pkl")
  131. os.remove("my_ds.pkl")
  132. class TestDataSetIter(unittest.TestCase):
  133. def test__repr__(self):
  134. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  135. for iter in ds:
  136. self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")