You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_var_batch_map.py 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. # Copyright 2019-2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. from mindspore import log as logger
  def test_batch_corner_cases():
      """
      Feature: Batch
      Description: Test corner cases of batch combined with repeat (in both orders), with drop_remainder
          set to both True and False, including a batch_size larger than a single epoch of data
      Expectation: Each produced batch and the number of batches match the expected values
      """

      def gen(num):
          # Single-column source: yields num rows holding np.array([0]) ... np.array([num - 1]).
          for value in range(num):
              yield (np.array([value]),)

      def collect(dataset, res):
          # Run one epoch over dataset and append the "num" column of every batch to res.
          res.extend(row["num"] for row in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))

      def test_repeat_batch(gen_num, repeats, batch_size, drop, res):
          # repeat first, then batch
          source = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"])
          collect(source.repeat(repeats).batch(batch_size, drop), res)

      def test_batch_repeat(gen_num, repeats, batch_size, drop, res):
          # batch first, then repeat
          source = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"])
          collect(source.batch(batch_size, drop).repeat(repeats), res)

      def column(*values):
          # Expected batch of single-element rows, i.e. an array of shape (len(values), 1).
          return np.array([[v] for v in values])

      tst1, tst2, tst3, tst4 = [], [], [], []

      # case 1 & 2, where batch_size is greater than the entire epoch, with drop equals to both val
      test_repeat_batch(gen_num=2, repeats=4, batch_size=7, drop=False, res=tst1)
      np.testing.assert_array_equal(column(0, 1, 0, 1, 0, 1, 0), tst1[0], err_msg="\nATTENTION BATCH FAILED\n")
      np.testing.assert_array_equal(column(1), tst1[1], err_msg="\nATTENTION TEST BATCH FAILED\n")
      assert len(tst1) == 2, "\nATTENTION TEST BATCH FAILED\n"

      test_repeat_batch(gen_num=2, repeats=4, batch_size=5, drop=True, res=tst2)
      np.testing.assert_array_equal(column(0, 1, 0, 1, 0), tst2[0], err_msg="\nATTENTION BATCH FAILED\n")
      assert len(tst2) == 1, "\nATTENTION TEST BATCH FAILED\n"

      # case 3 & 4, batch before repeat with different drop
      test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=True, res=tst3)
      np.testing.assert_array_equal(column(0, 1, 2, 3), tst3[0], err_msg="\nATTENTION BATCH FAILED\n")
      np.testing.assert_array_equal(tst3[0], tst3[1], err_msg="\nATTENTION BATCH FAILED\n")
      assert len(tst3) == 2, "\nATTENTION BATCH FAILED\n"

      test_batch_repeat(gen_num=5, repeats=2, batch_size=4, drop=False, res=tst4)
      np.testing.assert_array_equal(column(0, 1, 2, 3), tst4[0], err_msg="\nATTENTION BATCH FAILED\n")
      np.testing.assert_array_equal(tst4[0], tst4[2], err_msg="\nATTENTION BATCH FAILED\n")
      np.testing.assert_array_equal(tst4[1], column(4), err_msg="\nATTENTION BATCH FAILED\n")
      np.testing.assert_array_equal(tst4[1], tst4[3], err_msg="\nATTENTION BATCH FAILED\n")
      assert len(tst4) == 4, "\nATTENTION BATCH FAILED\n"
  def test_variable_size_batch():
      """
      Feature: Batch
      Description: Test batch variations with repeat and with/without per_batch_map.
          Each sub-test is tested with same parameters except that
          - the second test uses per_batch_map which passes each row a pyfunc and makes a deep copy of the row
          - the third test (if it exists) uses per_batch_map and python multiprocessing
      Expectation: Results are the same, independent of per_batch_map or python_multiprocessing settings
      """

      def check_res(arr1, arr2):
          # Compare lengths first: a missing or extra batch must yield False (so the assertion message is
          # shown) rather than an IndexError while indexing the shorter sequence.
          if len(arr1) != len(arr2):
              return False
          return all(np.array_equal(actual, np.array(expected)) for actual, expected in zip(arr1, arr2))

      def gen(num):
          for i in range(num):
              yield (np.array([i]),)

      def add_one_by_batch_num(batchInfo):
          # batch_size callable: the batch size is the current batch number plus one
          return batchInfo.get_batch_num() + 1

      def add_one_by_epoch(batchInfo):
          # batch_size callable: the batch size is the current epoch number plus one
          return batchInfo.get_epoch_num() + 1

      def simple_copy(colList, batchInfo):
          # per_batch_map callable: returns a copy of every row of the single input column
          _ = batchInfo
          return ([np.copy(arr) for arr in colList],)

      def test_repeat_batch(gen_num, r, drop, func, res):
          data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r).batch(batch_size=func,
                                                                                       drop_remainder=drop)
          for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
              res.append(item["num"])

      # same as test_repeat_batch except each row is passed through via a map which makes a copy of each element
      def test_repeat_batch_with_copy_map(gen_num, r, drop, func):
          res = []
          data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).repeat(r) \
              .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy)
          for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
              res.append(item["num"])
          return res

      def test_batch_repeat(gen_num, r, drop, func, res):
          data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size=func,
                                                                             drop_remainder=drop).repeat(r)
          for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
              res.append(item["num"])

      # same as test_batch_repeat except each row is passed through via a map which makes a copy of each element
      def test_batch_repeat_with_copy_map(gen_num, r, drop, func):
          res = []
          data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]) \
              .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy).repeat(r)
          for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
              res.append(item["num"])
          return res

      # same as test_batch_repeat_with_copy_map except with python multiprocessing enabled
      def test_batch_repeat_with_copy_map_multiproc(gen_num, r, drop, func, num_workers, my_maxrowsize):
          # Reduce memory required by disabling the shared memory optimization
          mem_original = ds.config.get_enable_shared_mem()
          ds.config.set_enable_shared_mem(False)
          try:
              res = []
              data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"], num_parallel_workers=num_workers,
                                          python_multiprocessing=True, max_rowsize=my_maxrowsize) \
                  .batch(batch_size=func, drop_remainder=drop, input_columns=["num"], per_batch_map=simple_copy,
                         num_parallel_workers=num_workers, python_multiprocessing=True,
                         max_rowsize=my_maxrowsize).repeat(r)
              for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
                  res.append(item["num"])
          finally:
              # Always restore the global setting, even when the pipeline raises, so that a failure here
              # does not change the configuration seen by later tests.
              ds.config.set_enable_shared_mem(mem_original)
          return res

      tst1, tst2, tst3, tst4, tst5, tst6, tst7 = [], [], [], [], [], [], []

      # no repeat, simple var size, based on batch_num
      test_repeat_batch(7, 1, True, add_one_by_batch_num, tst1)
      assert check_res(tst1, [[[0]], [[1], [2]], [[3], [4], [5]]]), "\nATTENTION VAR BATCH FAILED\n"
      assert check_res(tst1, test_repeat_batch_with_copy_map(7, 1, True, add_one_by_batch_num)), "\nMAP FAILED\n"
      test_repeat_batch(9, 1, False, add_one_by_batch_num, tst2)
      assert check_res(tst2, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]]), "\nATTENTION VAR BATCH FAILED\n"
      assert check_res(tst2, test_repeat_batch_with_copy_map(9, 1, False, add_one_by_batch_num)), "\nMAP FAILED\n"

      # batch after repeat, cross epoch batch
      test_repeat_batch(7, 2, False, add_one_by_batch_num, tst3)
      assert check_res(tst3, [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [0], [1], [2]],
                              [[3], [4], [5], [6]]]), "\nATTENTION VAR BATCH FAILED\n"
      assert check_res(tst3, test_repeat_batch_with_copy_map(7, 2, False, add_one_by_batch_num)), "\nMAP FAILED\n"

      # repeat after batch, no cross epoch batch, remainder dropped
      test_batch_repeat(9, 7, True, add_one_by_batch_num, tst4)
      assert check_res(tst4, [[[0]], [[1], [2]], [[3], [4], [5]]] * 7), "\nATTENTION VAR BATCH FAILED\n"
      assert check_res(tst4, test_batch_repeat_with_copy_map(9, 7, True, add_one_by_batch_num)), "\nMAP FAILED\n"

      # repeat after batch, no cross epoch batch, remainder kept
      test_batch_repeat(9, 3, False, add_one_by_batch_num, tst5)
      assert check_res(tst5,
                       [[[0]], [[1], [2]], [[3], [4], [5]], [[6], [7], [8]]] * 3), "\nATTENTION VAR BATCH FAILED\n"
      assert check_res(tst5, test_batch_repeat_with_copy_map(9, 3, False, add_one_by_batch_num)), "\nMAP FAILED\n"

      # batch_size based on epoch number, drop
      test_batch_repeat(4, 4, True, add_one_by_epoch, tst6)
      assert check_res(tst6, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]],
                              [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n"
      assert check_res(tst6, test_batch_repeat_with_copy_map(4, 4, True, add_one_by_epoch)), "\nMAP FAILED\n"

      # batch_size based on epoch number, no drop
      test_batch_repeat(4, 4, False, add_one_by_epoch, tst7)
      assert check_res(tst7, [[[0]], [[1]], [[2]], [[3]], [[0], [1]], [[2], [3]], [[0], [1], [2]], [[3]],
                              [[0], [1], [2], [3]]]), "\nATTENTION VAR BATCH FAILED\n" + str(tst7)
      assert check_res(tst7, test_batch_repeat_with_copy_map(4, 4, False, add_one_by_epoch)), "\nMAP FAILED\n"
      assert check_res(tst7, test_batch_repeat_with_copy_map_multiproc(
          4, 4, False, add_one_by_epoch, 4, 1)), "\nMULTIPROC1 MAP FAILED\n"
      assert check_res(tst7, test_batch_repeat_with_copy_map_multiproc(
          4, 4, False, add_one_by_epoch, 2, 2)), "\nMULTIPROC2 MAP FAILED\n"
  def test_basic_batch_map():
      """
      Feature: Batch
      Description: Test batch with a per_batch_map whose output depends on the epoch number or the batch
          number reported by the BatchInfo argument
      Expectation: The sign of each batch flips according to the epoch number / batch number
      """

      def check_res(arr1, arr2):
          # Compare lengths first: a missing or extra batch must yield False (so the assertion message is
          # shown) rather than an IndexError while indexing the shorter sequence.
          if len(arr1) != len(arr2):
              return False
          return all(np.array_equal(actual, np.array(expected)) for actual, expected in zip(arr1, arr2))

      def gen(num):
          for i in range(num):
              yield (np.array([i]),)

      def invert_sign_per_epoch(colList, batchInfo):
          # multiply every row by (-1) ** epoch_num
          return ([np.copy(((-1) ** batchInfo.get_epoch_num()) * arr) for arr in colList],)

      def invert_sign_per_batch(colList, batchInfo):
          # multiply every row by (-1) ** batch_num
          return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in colList],)

      def batch_map_config(num, r, batch_size, func, res):
          data1 = ds.GeneratorDataset((lambda: gen(num)), ["num"]) \
              .batch(batch_size=batch_size, input_columns=["num"], per_batch_map=func).repeat(r)
          for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
              res.append(item["num"])

      tst1, tst2 = [], []

      # each epoch, the sign of a row is changed, test map is correctly performed according to its epoch_num
      batch_map_config(4, 2, 2, invert_sign_per_epoch, tst1)
      assert check_res(tst1, [[[0], [1]], [[2], [3]], [[0], [-1]], [[-2], [-3]]]), \
          "\nATTENTION MAP BATCH FAILED\n" + str(tst1)

      # each batch, the sign of a row is changed, test map is correctly performed according to its batch_num
      batch_map_config(4, 2, 2, invert_sign_per_batch, tst2)
      assert check_res(tst2, [[[0], [1]], [[-2], [-3]], [[0], [1]], [[-2], [-3]]]), \
          "\nATTENTION MAP BATCH FAILED\n" + str(tst2)
  def test_batch_multi_col_map():
      """
      Feature: Batch
      Description: Test batch with per_batch_map applied to one or two input columns of a two-column
          dataset, including the order in which input_columns are handed to the map function
      Expectation: Only the selected columns are transformed, and they are passed in input_columns order
      """

      def check_res(arr1, arr2):
          # Compare lengths first: a missing or extra batch must yield False (so the assertion message is
          # shown) rather than an IndexError while indexing the shorter sequence.
          if len(arr1) != len(arr2):
              return False
          return all(np.array_equal(actual, np.array(expected)) for actual, expected in zip(arr1, arr2))

      def gen(num):
          # two columns per row: i and i ** 2
          for i in range(num):
              yield (np.array([i]), np.array([i ** 2]))

      def col1_col2_add_num(col1, col2, batchInfo):
          # first input column adds 100, second input column adds 300
          # note: the two output columns are returned as a list (the outer parentheses do not make a tuple)
          _ = batchInfo
          return ([[np.copy(arr + 100) for arr in col1],
                   [np.copy(arr + 300) for arr in col2]])

      def invert_sign_per_batch(colList, batchInfo):
          return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in colList],)

      def invert_sign_per_batch_multi_col(col1, col2, batchInfo):
          return ([np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in col1],
                  [np.copy(((-1) ** batchInfo.get_batch_num()) * arr) for arr in col2])

      def batch_map_config(num, r, batch_size, func, col_names, res):
          data1 = ds.GeneratorDataset((lambda: gen(num)), ["num", "num_square"]) \
              .batch(batch_size=batch_size, input_columns=col_names, per_batch_map=func).repeat(r)
          for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
              # each collected entry stacks both columns of one batch: [num, num_square]
              res.append(np.array([item["num"], item["num_square"]]))

      tst1, tst2, tst3, tst4 = [], [], [], []

      batch_map_config(4, 2, 2, invert_sign_per_batch, ["num_square"], tst1)
      assert check_res(tst1, [[[[0], [1]], [[0], [1]]], [[[2], [3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]],
                              [[[2], [3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst1)

      batch_map_config(4, 2, 2, invert_sign_per_batch_multi_col, ["num", "num_square"], tst2)
      assert check_res(tst2, [[[[0], [1]], [[0], [1]]], [[[-2], [-3]], [[-4], [-9]]], [[[0], [1]], [[0], [1]]],
                              [[[-2], [-3]], [[-4], [-9]]]]), "\nATTENTION MAP BATCH FAILED\n" + str(tst2)

      # the two tests below verify the order of the map.
      # num_square column adds 100, num column adds 300.
      batch_map_config(4, 3, 2, col1_col2_add_num, ["num_square", "num"], tst3)
      assert check_res(tst3, [[[[300], [301]], [[100], [101]]],
                              [[[302], [303]], [[104], [109]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst3)

      # num column adds 100, num_square column adds 300.
      batch_map_config(4, 3, 2, col1_col2_add_num, ["num", "num_square"], tst4)
      assert check_res(tst4, [[[[100], [101]], [[300], [301]]],
                              [[[102], [103]], [[304], [309]]]] * 3), "\nATTENTION MAP BATCH FAILED\n" + str(tst4)
  def test_var_batch_multi_col_map():
      """
      Feature: Batch
      Description: Test batch with a callable batch_size together with a per_batch_map applied to two of
          the three columns of the dataset
      Expectation: Batch sizes follow batch_func and the two mapped columns are scaled by the batch number
      """

      def check_res(arr1, arr2):
          # Compare lengths first: a missing or extra batch must yield False (so the assertion message is
          # shown) rather than an IndexError while indexing the shorter sequence.
          if len(arr1) != len(arr2):
              return False
          return all(np.array_equal(actual, np.array(expected)) for actual, expected in zip(arr1, arr2))

      # gen 3 columns
      # first column: 0, 3, 6, 9 ... ...
      # second column:1, 4, 7, 10 ... ...
      # third column: 2, 5, 8, 11 ... ...
      def gen_3_cols(num):
          for i in range(num):
              yield (np.array([i * 3]), np.array([i * 3 + 1]), np.array([i * 3 + 2]))

      # first epoch batch_size per batch: 1, 2 ,3 ... ...
      # second epoch batch_size per batch: 2, 4, 6 ... ...
      # third epoch batch_size per batch: 3, 6 ,9 ... ...
      def batch_func(batchInfo):
          return (batchInfo.get_batch_num() + 1) * (batchInfo.get_epoch_num() + 1)

      # multiply first col by (batch_num + 1), multiply second col by -(batch_num + 1)
      def map_func(col1, col2, batchInfo):
          return ([np.copy((1 + batchInfo.get_batch_num()) * arr) for arr in col1],
                  [np.copy(-(1 + batchInfo.get_batch_num()) * arr) for arr in col2])

      def batch_map_config(num, r, fbatch, fmap, col_names, res):
          data1 = ds.GeneratorDataset((lambda: gen_3_cols(num)), ["col1", "col2", "col3"]) \
              .batch(batch_size=fbatch, input_columns=col_names, per_batch_map=fmap).repeat(r)
          for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
              # each collected entry stacks the three columns of one batch
              res.append(np.array([item["col1"], item["col2"], item["col3"]]))

      tst1 = []
      tst1_res = [[[[0]], [[-1]], [[2]]], [[[6], [12]], [[-8], [-14]], [[5], [8]]],
                  [[[27], [36], [45]], [[-30], [-39], [-48]], [[11], [14], [17]]],
                  [[[72], [84], [96], [108]], [[-76], [-88], [-100], [-112]], [[20], [23], [26], [29]]]]
      batch_map_config(10, 1, batch_func, map_func, ["col1", "col2"], tst1)
      assert check_res(tst1, tst1_res), "test_var_batch_multi_col_map FAILED"
  def test_var_batch_var_resize():
      """
      Feature: Batch
      Description: Test batch on ImageFolderDataset where both the batch size and the per-batch image crop
          size are derived from the batch number
      Expectation: The i-th batch (counting from 1) has shape (i, i ** 2, i ** 2, 3)
      """

      # fake resize image according to its batch number, if it's 5-th batch, resize to (5^2, 5^2) = (25, 25)
      def np_pseudo_resize(col, batchInfo):
          side = (batchInfo.get_batch_num() + 1) ** 2
          return ([np.copy(image[:side, :side, :]) for image in col],)

      def add_one(batchInfo):
          return batchInfo.get_batch_num() + 1

      images = ds.ImageFolderDataset("../data/dataset/testPK/data/", num_parallel_workers=4, decode=True)
      batched = images.batch(batch_size=add_one, drop_remainder=True, input_columns=["image"],
                             per_batch_map=np_pseudo_resize)

      # i-th batch has shape [i, i^2, i^2, 3]
      for i, item in enumerate(batched.create_dict_iterator(num_epochs=1, output_numpy=True), start=1):
          assert item["image"].shape == (i, i ** 2, i ** 2, 3), "\ntest_var_batch_var_resize FAILED\n"
  def test_exception():
      """
      Feature: Batch
      Description: Test batch when the batch_size callable or the per_batch_map callable raises
          StopIteration
      Expectation: Iterating the dataset raises RuntimeError in both cases
      """

      def gen(num):
          for value in range(num):
              yield (np.array([value]),)

      def bad_batch_size(batchInfo):
          # batch_size callable that always fails instead of returning a size
          raise StopIteration

      def bad_map_func(col, batchInfo):
          # per_batch_map callable that always fails instead of returning columns
          raise StopIteration

      def expect_runtime_error(dataset):
          # Iterating must raise RuntimeError; finishing the iteration without one fails the test.
          try:
              for _ in dataset.create_dict_iterator(num_epochs=1):
                  pass
          except RuntimeError:
              return
          assert False

      data1 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(bad_batch_size)
      expect_runtime_error(data1)

      data2 = ds.GeneratorDataset((lambda: gen(100)), ["num"]).batch(4, input_columns=["num"],
                                                                     per_batch_map=bad_map_func)
      expect_runtime_error(data2)
  def test_multi_col_map():
      """
      Feature: Batch
      Description: Test batch with per_batch_map functions that split, merge or swap columns, together
          with output_columns and column_order, and the error messages for invalid configurations
      Expectation: Output columns hold the expected values and order; invalid settings report the expected
          error messages
      """

      def gen_2_cols(num):
          # two columns per row: i and i ** 2 for i = 1 .. num
          for i in range(1, 1 + num):
              yield (np.array([i]), np.array([i ** 2]))

      def split_col(col, batchInfo):
          # one input column -> two output columns: a copy and a negated copy
          copies = [np.copy(arr) for arr in col]
          negated = [np.copy(-arr) for arr in col]
          return (copies, negated)

      def merge_col(col1, col2, batchInfo):
          # two input columns -> one output column holding the row-wise sum
          return ([np.array(left + right) for left, right in zip(col1, col2)],)

      def swap_col(col1, col2, batchInfo):
          # two input columns -> the same two columns in reversed order
          return ([np.copy(a) for a in col2], [np.copy(b) for b in col1])

      def batch_map_config(num, s, f, in_nms, out_nms, col_order=None):
          # Returns the list of output rows on success, or the error message text on failure.
          try:
              dst = ds.GeneratorDataset((lambda: gen_2_cols(num)), ["col1", "col2"])
              dst = dst.batch(batch_size=s, input_columns=in_nms, output_columns=out_nms, per_batch_map=f,
                              column_order=col_order)
              return list(dst.create_dict_iterator(num_epochs=1, output_numpy=True))
          except (ValueError, RuntimeError, TypeError) as e:
              return str(e)

      # split 1 col into 2 cols
      res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"])[0]
      assert np.array_equal(res["col1"], [[1], [2]])
      assert np.array_equal(res["col_x"], [[1], [4]])
      assert np.array_equal(res["col_y"], [[-1], [-4]])

      # merge 2 cols into 1 col
      res = batch_map_config(4, 4, merge_col, ["col1", "col2"], ["merged"])[0]
      assert np.array_equal(res["merged"], [[2], [6], [12], [20]])

      # swap once
      res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col1", "col2"])[0]
      assert np.array_equal(res["col1"], [[1], [4], [9]])
      assert np.array_equal(res["col2"], [[1], [2], [3]])

      # swap twice
      res = batch_map_config(3, 3, swap_col, ["col1", "col2"], ["col2", "col1"])[0]
      assert np.array_equal(res["col2"], [[1], [4], [9]])
      assert np.array_equal(res["col1"], [[1], [2], [3]])

      # test project after map
      res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col_x", "col_y", "col1"])[0]
      assert list(res.keys()) == ["col_x", "col_y", "col1"]

      # test the insertion order is maintained
      res = batch_map_config(2, 2, split_col, ["col2"], ["col_x", "col_y"], ["col1", "col_x", "col_y"])[0]
      assert list(res.keys()) == ["col1", "col_x", "col_y"]

      # test exceptions: (expected message fragment, arguments following split_col)
      expected_errors = [
          ("output_columns with value 233 is not of type", (["col2"], 233)),
          ("column_order with value 233 is not of type", (["col2"], ["col1"], 233)),
          ("columns that are not involved in 'per_batch_map' should not be in output_columns",
           (["col2"], ["col1"])),
          ("the number of columns returned in 'per_batch_map' function should be 3",
           (["col2"], ["col3", "col4", "col5"])),
          ("'col-1' of 'input_columns' doesn't exist", (["col-1"], ["col_x", "col_y"])),
      ]
      for message, column_args in expected_errors:
          assert message in batch_map_config(2, 2, split_col, *column_args)
  def test_exceptions_2():
      """
      Feature: Batch
      Description: Test batch with an input column that does not exist and with per_batch_map functions
          that return more or fewer rows than they were given
      Expectation: Iterating the dataset raises RuntimeError carrying the expected message
      """

      def gen(num):
          for value in range(num):
              yield (np.array([value]),)

      def simple_copy(col_list, batch_info):
          return ([np.copy(arr) for arr in col_list],)

      def concat_copy(col_list, batch_info):
          # this will double the number of rows returned, which would be wrong!
          copies = [np.copy(arr) for arr in col_list]
          return (copies * 2,)

      def shrink_copy(col_list, batch_info):
          # this will return only the first half of the rows, which would be wrong!
          return ([np.copy(arr) for arr in col_list[:len(col_list) // 2]],)

      def test_exceptions_config(gen_num, batch_size, in_cols, per_batch_map):
          # Returns "success" when one epoch runs cleanly, otherwise the RuntimeError message text.
          data1 = ds.GeneratorDataset((lambda: gen(gen_num)), ["num"]).batch(batch_size, input_columns=in_cols,
                                                                             per_batch_map=per_batch_map)
          try:
              for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
                  pass
          except RuntimeError as e:
              return str(e)
          else:
              return "success"

      # test exception where column name is incorrect
      assert "'num1' of 'input_columns' doesn't exist" in test_exceptions_config(4, 2, ["num1"], simple_copy)
      # test exceptions where per_batch_map returns the wrong number of rows
      assert "expects: 2 rows returned from 'per_batch_map', got: 4" in test_exceptions_config(4, 2, ["num"],
                                                                                               concat_copy)
      assert "expects: 4 rows returned from 'per_batch_map', got: 2" in test_exceptions_config(4, 4, ["num"],
                                                                                               shrink_copy)
  if __name__ == '__main__':
      # Run every test in this file in order, logging which one is about to start.
      # The log message is derived from the function name so the two cannot drift apart
      # (the original message for test_var_batch_multi_col_map contained a typo).
      for test_func in (test_batch_corner_cases,
                        test_variable_size_batch,
                        test_basic_batch_map,
                        test_batch_multi_col_map,
                        test_var_batch_multi_col_map,
                        test_var_batch_var_resize,
                        test_exception,
                        test_multi_col_map,
                        test_exceptions_2):
          logger.info("Running test_var_batch_map.py {}() function".format(test_func.__name__))
          test_func()