You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import unittest
  2. from fastNLP.core.dataset import DataSet
  3. from fastNLP.core.instance import Instance
  4. class TestDataSet(unittest.TestCase):
  def test_init_v1(self):
      """Build a DataSet from a list of Instances and check the resulting fields.

      Each row is a separate Instance object rather than 40 references to a
      single one, so the test does not depend on shared-object aliasing.
      """
      ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6]) for _ in range(40)])
      # Separate assertIn calls report exactly which field is missing,
      # unlike a single compound assertTrue.
      self.assertIn("x", ds.field_arrays)
      self.assertIn("y", ds.field_arrays)
      self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4]] * 40)
      self.assertEqual(ds.field_arrays["y"].content, [[5, 6]] * 40)
  def test_init_v2(self):
      """Build a DataSet from a dict mapping field name to column values."""
      ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
      # Separate assertIn calls report exactly which field is missing.
      self.assertIn("x", ds.field_arrays)
      self.assertIn("y", ds.field_arrays)
      self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4]] * 40)
      self.assertEqual(ds.field_arrays["y"].content, [[5, 6]] * 40)
  def test_init_assert(self):
      """The DataSet constructor must reject invalid input.

      Expected failures:
        * a dict whose columns differ in length (40 vs 100) -> AssertionError
        * a list of plain lists instead of Instances        -> AssertionError
        * a bare float                                      -> ValueError
      """
      ragged_columns = {"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}
      self.assertRaises(AssertionError, DataSet, ragged_columns)

      plain_lists = [[1, 2, 3, 4]] * 10
      self.assertRaises(AssertionError, DataSet, plain_lists)

      self.assertRaises(ValueError, DataSet, 0.00001)
  def test_append(self):
      """Appending Instances one at a time grows every field column."""
      ds = DataSet()
      n_rows = 3
      for _ in range(n_rows):
          ds.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))

      self.assertEqual(len(ds), n_rows)
      expected_columns = {"x": [[1, 2, 3, 4]] * n_rows, "y": [[5, 6]] * n_rows}
      for field_name, column in expected_columns.items():
          self.assertEqual(ds.field_arrays[field_name].content, column)
  def test_add_append(self):
      """Add whole columns with add_field, then check length and contents.

      Despite the method name, only ``add_field`` is exercised here;
      ``append`` is never called.
      """
      # One template row per field; each column repeats it 10 times.
      rows = {"x": [1, 2, 3], "y": [1, 2, 3, 4], "z": [5, 6]}

      ds = DataSet()
      for field_name, row in rows.items():
          ds.add_field(field_name, [list(row)] * 10)

      self.assertEqual(len(ds), 10)
      # Compare against freshly built lists, independent of what was passed in.
      for field_name, row in rows.items():
          self.assertEqual(ds.field_arrays[field_name].content, [list(row)] * 10)
  def test_delete_field(self):
      """Deleting one field removes it and leaves the other fields intact."""
      ds = DataSet()
      ds.add_field("x", [[1, 2, 3]] * 10)
      ds.add_field("y", [[1, 2, 3, 4]] * 10)
      ds.delete_field("x")
      # assertNotIn/assertIn print the container on failure, unlike
      # assertFalse/assertTrue on a bare membership expression.
      self.assertNotIn("x", ds.field_arrays)
      self.assertIn("y", ds.field_arrays)
      self.assertEqual(ds.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
  def test_getitem(self):
      """Integer indexing yields Instances; slicing yields a smaller DataSet."""
      ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
      # Each variable is named after the index it was read from.
      ins_0, ins_1 = ds[0], ds[1]
      for ins in (ins_0, ins_1):
          self.assertIsInstance(ins, Instance)
          self.assertEqual(ins["x"], [1, 2, 3, 4])
          self.assertEqual(ins["y"], [5, 6])

      sub_ds = ds[:10]
      self.assertIsInstance(sub_ds, DataSet)
      self.assertEqual(len(sub_ds), 10)
      self.assertEqual(sub_ds.field_arrays["x"].content, [[1, 2, 3, 4]] * 10)
  def test_apply(self):
      """apply() writes the function's per-row result into a new field.

      Every row of the new field is checked (not just the first), and the
      source field must be left unchanged by the reversing lambda.
      """
      ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
      ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
      self.assertIn("rx", ds.field_arrays)
      self.assertEqual(ds.field_arrays["rx"].content, [[4, 3, 2, 1]] * 40)
      self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4]] * 40)