You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 19 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. import os
  2. import unittest
  3. import numpy as np
  4. from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException
  class TestDataSetInit(unittest.TestCase):
      """Construction paths for ``DataSet``.

      A DataSet can be initialised in the following ways:

      1) From a dict mapping field name -> per-instance values:
         1.1) 2-D list     DataSet({"x": [[1, 2], [3, 4]]})
         1.2) 2-D ndarray  DataSet({"x": np.array([[1, 2], [3, 4]])})
         1.3) 3-D list     DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
      2) From a list of ``Instance`` objects:
         2.1) 1-D list     DataSet([Instance(x=[1, 2, 3, 4])])
         2.2) 1-D ndarray  DataSet([Instance(x=np.array([1, 2, 3, 4]))])
         2.3) 2-D list     DataSet([Instance(x=[[1, 2], [3, 4]])])
         2.4) 2-D ndarray  DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])

      Only plain lists, or an ndarray at the outermost level, are accepted.
      """

      def _check_xy_fields(self, ds, count):
          """Assert ``ds`` holds ``count`` copies of x=[1, 2, 3, 4] and y=[5, 6]."""
          arrays = ds.field_arrays
          self.assertTrue("x" in arrays and "y" in arrays)
          self.assertEqual(arrays["x"].content, [[1, 2, 3, 4]] * count)
          self.assertEqual(arrays["y"].content, [[5, 6]] * count)

      def test_init_v1(self):
          # Built from a list of Instance objects holding 1-D lists.
          samples = [Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40
          self._check_xy_fields(DataSet(samples), 40)

      def test_init_v2(self):
          # Built from a dict of per-instance values.
          columns = {"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}
          self._check_xy_fields(DataSet(columns), 40)

      def test_init_assert(self):
          # Fields of different lengths are rejected.
          with self.assertRaises(AssertionError):
              DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
          # A bare list of lists (not Instance objects) is rejected.
          with self.assertRaises(AssertionError):
              DataSet([[1, 2, 3, 4]] * 10)
          # A scalar is not a supported input at all.
          with self.assertRaises(ValueError):
              DataSet(0.00001)
  class TestDataSetMethods(unittest.TestCase):
      """Behavioural tests for the public ``DataSet`` API."""

      def test_append(self):
          dd = DataSet()
          for _ in range(3):
              dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
          self.assertEqual(len(dd), 3)
          self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
          self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)

      def test_add_field(self):
          dd = DataSet()
          dd.add_field("x", [[1, 2, 3]] * 10)
          dd.add_field("y", [[1, 2, 3, 4]] * 10)
          dd.add_field("z", [[5, 6]] * 10)
          self.assertEqual(len(dd), 10)
          self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
          self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
          self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
          # A field whose length differs from the existing fields is rejected.
          with self.assertRaises(RuntimeError):
              dd.add_field("??", [[1, 2]] * 40)

      def test_delete_field(self):
          dd = DataSet()
          dd.add_field("x", [[1, 2, 3]] * 10)
          dd.add_field("y", [[1, 2, 3, 4]] * 10)
          dd.delete_field("x")
          self.assertFalse("x" in dd.field_arrays)
          self.assertTrue("y" in dd.field_arrays)

      def test_delete_instance(self):
          dd = DataSet()
          old_length = 2
          dd.add_field("x", [[1, 2, 3]] * old_length)
          dd.add_field("y", [[1, 2, 3, 4]] * old_length)
          dd.delete_instance(0)
          self.assertEqual(len(dd), old_length - 1)
          dd.delete_instance(0)
          self.assertEqual(len(dd), old_length - 2)

      def test_getitem(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          # Integer index -> Instance.
          ins_0, ins_1 = ds[0], ds[1]
          self.assertTrue(isinstance(ins_0, Instance) and isinstance(ins_1, Instance))
          self.assertEqual(ins_0["x"], [1, 2, 3, 4])
          self.assertEqual(ins_0["y"], [5, 6])
          self.assertEqual(ins_1["x"], [1, 2, 3, 4])
          self.assertEqual(ins_1["y"], [5, 6])
          # Slice -> DataSet.
          sub_ds = ds[:10]
          self.assertTrue(isinstance(sub_ds, DataSet))
          self.assertEqual(len(sub_ds), 10)
          # List of indices -> DataSet.
          sub_ds_1 = ds[[10, 0, 2, 3]]
          self.assertTrue(isinstance(sub_ds_1, DataSet))
          self.assertEqual(len(sub_ds_1), 4)
          # Field name -> FieldArray.
          field_array = ds['x']
          self.assertTrue(isinstance(field_array, FieldArray))
          self.assertEqual(len(field_array), 40)

      def test_get_item_error(self):
          # Build the DataSet outside ``assertRaises`` so that an error raised by
          # the constructor cannot be mistaken for the expected indexing error.
          ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
          with self.assertRaises(RuntimeError):
              _ = ds[40:]
          with self.assertRaises(KeyError):
              _ = ds["kom"]

      def test_len_(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          self.assertEqual(len(ds), 40)
          ds = DataSet()
          self.assertEqual(len(ds), 0)

      def test_add_fieldarray(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40))
          self.assertEqual(ds['z'].content, [[7, 8]] * 40)
          # Length mismatch with the existing fields.
          with self.assertRaises(RuntimeError):
              ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10))
          # A plain list is rejected.
          with self.assertRaises(TypeError):
              ds.add_fieldarray('z', [1, 2, 4])

      def test_copy_field(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          ds.copy_field('x', 'z')
          self.assertEqual(ds['x'].content, ds['z'].content)

      def test_has_field(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          self.assertTrue(ds.has_field('x'))
          self.assertFalse(ds.has_field('z'))

      def test_get_field(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          with self.assertRaises(KeyError):
              ds.get_field('z')
          x_array = ds.get_field('x')
          self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40)

      def test_get_all_fields(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          field_arrays = ds.get_all_fields()
          # Compare the stored values, as test_get_field does, instead of relying
          # on a FieldArray comparing equal to a plain list.
          self.assertEqual(field_arrays["x"].content, [[1, 2, 3, 4]] * 40)
          self.assertEqual(field_arrays['y'].content, [[5, 6]] * 40)

      def test_get_field_names(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          field_names = ds.get_field_names()
          self.assertTrue('x' in field_names)
          self.assertTrue('y' in field_names)

      def test_apply(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000})
          ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx')
          self.assertTrue("rx" in ds.field_arrays)
          self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
          # Writing to an existing field name replaces its content.
          ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
          self.assertEqual(ds.field_arrays["y"].content[0], 2)
          # Without new_field_name the per-instance results come back as a list.
          res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
          self.assertIsInstance(res, list)
          self.assertGreater(len(res), 0)
          # assertEqual, not assertTrue: assertTrue's second argument is only the
          # failure message, so assertTrue(res[0], 4) accepted any truthy value.
          self.assertEqual(res[0], 4)
          # Tuple-valued results can be stored in a new field without raising.
          ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k")

      def test_apply_progress_bar(self):
          import time
          ds = DataSet({"x": [[1, 2, 3, 4]] * 400, "y": [[5, 6]] * 400})

          def do_nothing(ins):
              # Presumably slowed down so the progress bar has time to render.
              time.sleep(0.01)

          ds.apply(do_nothing, show_progress_bar=True, num_proc=0)
          ds.apply_field(do_nothing, field_name='x', show_progress_bar=True)

      def test_apply_cannot_modify_instance(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})

          def modify_inplace(instance):
              instance['words'] = 1

          # Mutating the instance passed to ``apply`` is currently tolerated:
          # the call only has to complete without raising.
          ds.apply(modify_inplace)

      def test_apply_more(self):
          T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})
          func_1 = lambda x: {"c": x["a"] * 2, "d": x["a"] ** 2}
          func_2 = lambda x: {"c": x * 3, "d": x ** 3}

          def func_err_1(x):
              # Returns a different set of keys for the first instance.
              if x["a"] == 1:
                  return {"e": x["a"] * 2, "f": x["a"] ** 2}
              else:
                  return {"e": x["a"] * 2}

          def func_err_2(x):
              # Field-level variant of func_err_1.
              if x == 1:
                  return {"e": x * 2, "f": x ** 2}
              else:
                  return {"e": x * 2}

          # Every key of the returned dict becomes a field.
          T.apply_more(func_1)
          self.assertEqual(list(T["c"].content), [2, 4, 6])
          self.assertEqual(list(T["d"].content), [1, 4, 9])
          # modify_fields=False returns the results and leaves existing fields alone.
          res = T.apply_field_more(func_2, "a", modify_fields=False)
          self.assertEqual(list(T["c"].content), [2, 4, 6])
          self.assertEqual(list(T["d"].content), [1, 4, 9])
          self.assertEqual(list(res["c"]), [3, 6, 9])
          self.assertEqual(list(res["d"]), [1, 8, 27])
          # Inconsistent keys across instances are reported as ApplyResultException.
          with self.assertRaises(ApplyResultException):
              T.apply_more(func_err_1)
          with self.assertRaises(ApplyResultException):
              T.apply_field_more(func_err_2, "a")

      def test_drop(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
          ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
          self.assertEqual(len(ds), 20)

      def test_contains(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          self.assertTrue("x" in ds)
          self.assertTrue("y" in ds)
          self.assertFalse("z" in ds)

      def test_rename_field(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
          ds.rename_field("x", "xx")
          self.assertTrue("xx" in ds)
          self.assertFalse("x" in ds)
          with self.assertRaises(KeyError):
              ds.rename_field("yyy", "oo")

      def test_split(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
          d1, d2 = ds.split(0.1)
          # Exact integers rather than float products such as len(ds) * 0.9.
          self.assertEqual(len(d1), 9)
          self.assertEqual(len(d2), 1)

      def test_add_field_v2(self):
          # A field whose per-instance values have different lengths.
          ragged = [['hello', 'world'], ['this', 'is', 'a', 'test']]
          ds = DataSet({"x": [3, 4]})
          ds.add_field('y', ragged)
          self.assertEqual(len(ds), 2)
          self.assertEqual(ds['y'].content, ragged)
          # Rendering a DataSet with a ragged field must not raise.
          self.assertIsInstance(str(ds), str)

      def test_save_load(self):
          import tempfile
          ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
          # A temporary directory keeps the CWD clean and is removed even if an
          # assertion fails.
          with tempfile.TemporaryDirectory() as tmp_dir:
              path = os.path.join(tmp_dir, "my_ds.pkl")
              ds.save(path)
              self.assertTrue(os.path.exists(path))
              ds_1 = DataSet.load(path)
          # The round trip must preserve the data, not just produce a file.
          self.assertEqual(len(ds_1), len(ds))
          self.assertEqual(ds_1["x"].content, ds["x"].content)
          self.assertEqual(ds_1["y"].content, ds["y"].content)

      def test_add_null(self):
          ds = DataSet()
          with self.assertRaises(RuntimeError):
              ds.add_field('test', [])

      def test_concat(self):
          """Check that two DataSets are concatenated correctly."""
          # Default call.
          ds1 = DataSet({"x": [[1, 2, 3, 4] for _ in range(10)], "y": [[5, 6] for _ in range(10)]})
          ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
          ds3 = ds1.concat(ds2)
          self.assertEqual(len(ds3), 20)
          # Inspect the concatenated result: indices 9 and 10 straddle the boundary.
          self.assertListEqual(ds3[9]['x'], [1, 2, 3, 4])
          self.assertListEqual(ds3[10]['x'], [4, 3, 2, 1])
          ds2[0]['x'][0] = 100
          self.assertEqual(ds3[10]['x'][0], 4)  # changing ds2 does not affect ds3
          ds3[10]['x'][0] = -100
          self.assertEqual(ds2[0]['x'][0], 100)  # changing ds3 does not affect ds2

          # inplace=True
          ds1 = DataSet({"x": [[1, 2, 3, 4] for _ in range(10)], "y": [[5, 6] for _ in range(10)]})
          ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
          ds3 = ds1.concat(ds2, inplace=True)
          ds2[0]['x'][0] = 100
          self.assertEqual(ds3[10]['x'][0], 4)  # changing ds2 does not affect ds3
          ds3[10]['x'][0] = -100
          self.assertEqual(ds2[0]['x'][0], 100)  # changing ds3 does not affect ds2
          ds3[0]['x'][0] = 100
          self.assertEqual(ds1[0]['x'][0], 100)  # with inplace=True, ds3 changes show up in ds1

          # field_mapping renames the incoming fields.
          ds1 = DataSet({"x": [[1, 2, 3, 4] for _ in range(10)], "y": [[5, 6] for _ in range(10)]})
          ds2 = DataSet({"X": [[4, 3, 2, 1] for _ in range(10)], "Y": [[6, 5] for _ in range(10)]})
          ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
          self.assertEqual(len(ds3), 20)

          # Extra fields in the incoming DataSet are ignored.
          ds1 = DataSet({"x": [[1, 2, 3, 4] for _ in range(10)], "y": [[5, 6] for _ in range(10)]})
          ds2 = DataSet({"X": [[4, 3, 2, 1] for _ in range(10)], "Y": [[6, 5] for _ in range(10)], 'Z': [0] * 10})
          ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
          self.assertEqual(len(ds3), 20)
          self.assertFalse(ds3.has_field('Z'))

          # A missing field in the incoming DataSet is an error.
          ds1 = DataSet({"x": [[1, 2, 3, 4] for _ in range(10)], "y": [[5, 6] for _ in range(10)]})
          ds2 = DataSet({"X": [[4, 3, 2, 1] for _ in range(10)]})
          with self.assertRaises(RuntimeError):
              ds1.concat(ds2, field_mapping={'X': 'x'})

      def test_instance_field_disappear_bug(self):
          # Regression test: fields added by copy_field must survive slicing.
          data = DataSet({'raw_chars': [[0, 1], [2]], 'target': [0, 1]})
          data.copy_field(field_name='raw_chars', new_field_name='chars')
          _data = data[:1]
          for field_name in ['raw_chars', 'target', 'chars']:
              self.assertTrue(_data.has_field(field_name))

      def test_from_pandas(self):
          import pandas as pd
          df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
          ds = DataSet.from_pandas(df)
          self.assertEqual(ds['x'].content, [1, 2, 3])
          self.assertEqual(ds['y'].content, [4, 5, 6])

      def test_to_pandas(self):
          ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
          df = ds.to_pandas()
          # Mirror of test_from_pandas: each field becomes a column.
          self.assertEqual(df['x'].tolist(), [1, 2, 3])
          self.assertEqual(df['y'].tolist(), [4, 5, 6])

      def test_to_csv(self):
          import tempfile
          ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
          # Write into a temporary directory so no file is left behind on failure.
          with tempfile.TemporaryDirectory() as tmp_dir:
              path = os.path.join(tmp_dir, "1.csv")
              ds.to_csv(path)
              self.assertTrue(os.path.exists(path))

      def test_add_collate_fn(self):
          ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})

          def collate_fn(item):
              return item

          ds.add_collate_fn(collate_fn)
          self.assertEqual(len(ds.collate_fns.collators), 2)

      def test_get_collator(self):
          ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
          collate_fn = ds.get_collator()
          self.assertTrue(callable(collate_fn))

      def test_add_seq_len(self):
          ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
          ds.add_seq_len('x')
          # NOTE(review): assumes the default output field name is 'seq_len'.
          self.assertEqual(ds['seq_len'].content, [2, 3, 1])

      def test_set_target(self):
          # Smoke test: set_target on an existing field must not raise.
          ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
          ds.set_target('x')
  class TestFieldArrayInit(unittest.TestCase):
      """``FieldArray`` construction from the value shapes a DataSet can feed it.

      1) When a DataSet is initialised from a dict, ``add_field`` constructs the
         FieldArray:
         1.1) 2-D list     DataSet({"x": [[1, 2], [3, 4]]})
         1.2) 2-D ndarray  DataSet({"x": np.array([[1, 2], [3, 4]])})
         1.3) 3-D list     DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
      2) When a DataSet is initialised from a list of Instance, ``append`` first
         builds the FieldArray from the first sample, then adds the remaining
         samples with ``FieldArray.append``:
         2.1) 1-D list     DataSet([Instance(x=[1, 2, 3, 4])])
         2.2) 1-D ndarray  DataSet([Instance(x=np.array([1, 2, 3, 4]))])
         2.3) 2-D list     DataSet([Instance(x=[[1, 2], [3, 4]])])
         2.4) 2-D ndarray  DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])

      Each test checks that the resulting FieldArray holds one entry per sample.
      """

      def test_init_v1(self):
          # 2-D list: 10 samples, each a list of 2 ints.
          fa = FieldArray("x", [[1, 2], [3, 4]] * 5)
          self.assertEqual(len(fa), 10)

      def test_init_v2(self):
          # 2-D ndarray of shape (10, 2).
          fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5))
          self.assertEqual(len(fa), 10)

      def test_init_v3(self):
          # 3-D list: 2 samples, each a 2x2 nested list.
          fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]])
          self.assertEqual(len(fa), 2)

      def test_init_v4(self):
          # Each sample is a 1-D list.
          val = [1, 2, 3, 4]
          fa = FieldArray("x", [val])
          fa.append(val)
          self.assertEqual(len(fa), 2)

      def test_init_v5(self):
          # Each sample is a 1-D ndarray.
          val = np.array([1, 2, 3, 4])
          fa = FieldArray("x", [val])
          fa.append(val)
          self.assertEqual(len(fa), 2)

      def test_init_v6(self):
          # Each sample is a 2-D list.
          val = [[1, 2], [3, 4]]
          fa = FieldArray("x", [val])
          fa.append(val)
          self.assertEqual(len(fa), 2)

      def test_init_v7(self):
          # A list of 2-D ndarrays.
          fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])])
          self.assertEqual(len(fa), 2)

      def test_init_v8(self):
          # Each sample is a 2-D ndarray.
          val = np.array([[1, 2], [3, 4]])
          fa = FieldArray("x", [val])
          fa.append(val)
          self.assertEqual(len(fa), 2)
  class TestFieldArray(unittest.TestCase):
      """Element access, appending and removal on ``FieldArray``."""

      def _float_rows_field(self):
          """Return a fresh two-row float FieldArray named "y"."""
          return FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])

      def test_main(self):
          fa = FieldArray("x", [1, 2, 3, 4, 5])
          self.assertEqual(len(fa), 5)

          fa.append(6)
          self.assertEqual(len(fa), 6)
          self.assertEqual(fa[-1], 6)
          self.assertEqual(fa[0], 1)

          fa[-1] = 60
          self.assertEqual(fa[-1], 60)

          # get() takes a single index or a list of indices (the latter -> ndarray).
          self.assertEqual(fa.get(0), 1)
          picked = fa.get([0, 1, 2])
          self.assertIsInstance(picked, np.ndarray)
          self.assertListEqual(list(picked), [1, 2, 3])

      def test_getitem_v1(self):
          fa = self._float_rows_field()
          self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])

          # Indexing with a list yields a float64 ndarray of ndarray rows.
          batch = fa[[0, 1]]
          self.assertIsInstance(batch, np.ndarray)
          self.assertIsInstance(batch[0], np.ndarray)
          first_row, second_row = batch[0].tolist(), batch[1].tolist()
          self.assertEqual(first_row, [1.1, 2.2, 3.3, 4.4, 5.5])
          self.assertEqual(second_row, [1, 2, 3, 4, 5])
          self.assertEqual(batch.dtype, np.float64)

      def test_getitem_v2(self):
          matrix = np.random.rand(10, 5)
          fa = FieldArray("my_field", matrix)
          rows = [0, 1, 3, 4, 6]
          for got, want in zip(fa[rows], matrix[rows]):
              self.assertListEqual(got.tolist(), want.tolist())

      def test_append(self):
          fa = self._float_rows_field()
          new_row = [1.2, 2.3, 3.4, 4.5, 5.6]
          fa.append(new_row)
          self.assertEqual(len(fa), 3)
          self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])

      def test_pop(self):
          fa = self._float_rows_field()
          fa.pop(0)
          # Removing the first row shifts the second one to index 0.
          self.assertEqual(len(fa), 1)
          self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0])
          fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5]
          self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
  class TestCase(unittest.TestCase):
      """Tests for ``Instance``, the per-sample object stored in and returned by DataSet."""

      def test_init(self):
          expected = {"x": [1, 2, 3], "y": [4, 5, 6]}
          # Keyword construction.
          by_keyword = Instance(x=[1, 2, 3], y=[4, 5, 6])
          self.assertIsInstance(by_keyword.fields, dict)
          self.assertEqual(by_keyword.fields, expected)
          # Construction from an unpacked dict.
          by_unpacking = Instance(**expected)
          self.assertEqual(by_unpacking.fields, expected)

      def test_add_field(self):
          base = {"x": [1, 2, 3], "y": [4, 5, 6]}
          ins = Instance(**base)
          ins.add_field("z", [1, 1, 1])
          self.assertEqual(ins.fields, {**base, "z": [1, 1, 1]})

      def test_get_item(self):
          fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
          ins = Instance(**fields)
          for name, value in fields.items():
              self.assertEqual(ins[name], value)

      def test_repr(self):
          fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
          # Smoke test: printing an Instance must not raise.
          print(Instance(**fields))