You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_fieldarray.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
import unittest

import numpy as np

from fastNLP.core.fieldarray import FieldArray
  4. class TestFieldArray(unittest.TestCase):
  5. def test(self):
  6. fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
  7. self.assertEqual(len(fa), 5)
  8. fa.append(6)
  9. self.assertEqual(len(fa), 6)
  10. self.assertEqual(fa[-1], 6)
  11. self.assertEqual(fa[0], 1)
  12. fa[-1] = 60
  13. self.assertEqual(fa[-1], 60)
  14. self.assertEqual(fa.get(0), 1)
  15. self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
  16. self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
  17. def test_type_conversion(self):
  18. fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
  19. self.assertEqual(fa.pytype, float)
  20. self.assertEqual(fa.dtype, np.float64)
  21. fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
  22. fa.append(1.3333)
  23. self.assertEqual(fa.pytype, float)
  24. self.assertEqual(fa.dtype, np.float64)
  25. fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False)
  26. fa.append(10)
  27. self.assertEqual(fa.pytype, float)
  28. self.assertEqual(fa.dtype, np.float64)
  29. fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False)
  30. fa.append("e")
  31. self.assertEqual(fa.dtype, np.str)
  32. self.assertEqual(fa.pytype, str)
  33. def test_support_np_array(self):
  34. fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False)
  35. self.assertEqual(fa.dtype, np.ndarray)
  36. fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5]))
  37. self.assertEqual(fa.pytype, np.ndarray)
  38. def test_nested_list(self):
  39. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False)
  40. self.assertEqual(fa.pytype, float)
  41. self.assertEqual(fa.dtype, np.float64)