test_batch.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
from util import save_and_check

# Note: Number of rows in test.data dataset: 12
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
GENERATE_GOLDEN = False
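
# The batch counts asserted in the tests below follow directly from the
# 12-row dataset size:
#   drop_remainder=True  -> floor(12 / batch_size) batches (leftover rows dropped)
#   drop_remainder=False -> ceil(12 / batch_size) batches (final partial batch kept)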


def test_batch_01():
    """
    Test batch: batch_size>1, drop_remainder=True, no remainder exists
    """
    logger.info("test_batch_01")
    # define parameters
    batch_size = 2
    drop_remainder = True
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)

    assert sum([1 for _ in data1]) == 6

    filename = "batch_01_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_02():
    """
    Test batch: batch_size>1, drop_remainder=True, remainder exists
    """
    logger.info("test_batch_02")
    # define parameters
    batch_size = 5
    drop_remainder = True
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
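
    # floor(12 / 5) = 2 full batches; the 2 leftover rows are dropped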
    assert sum([1 for _ in data1]) == 2

    filename = "batch_02_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_03():
    """
    Test batch: batch_size>1, drop_remainder=False, no remainder exists
    """
    logger.info("test_batch_03")
    # define parameters
    batch_size = 3
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)

    assert sum([1 for _ in data1]) == 4

    filename = "batch_03_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_04():
    """
    Test batch: batch_size>1, drop_remainder=False, remainder exists
    """
    logger.info("test_batch_04")
    # define parameters
    batch_size = 7
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)
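
    # ceil(12 / 7) = 2: one full batch of 7 plus a partial batch of 5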
    assert sum([1 for _ in data1]) == 2

    filename = "batch_04_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_05():
    """
    Test batch: batch_size=1 (minimum valid size), drop_remainder default
    """
    logger.info("test_batch_05")
    # define parameters
    batch_size = 1
    parameters = {"params": {'batch_size': batch_size}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size)

    assert sum([1 for _ in data1]) == 12

    filename = "batch_05_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_06():
    """
    Test batch: batch_size = number-of-rows-in-dataset, drop_remainder=False, reorder params
    """
    logger.info("test_batch_06")
    # define parameters
    batch_size = 12
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)

    assert sum([1 for _ in data1]) == 1

    filename = "batch_06_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_07():
    """
    Test batch: num_parallel_workers>1, drop_remainder=False, reorder params
    """
    logger.info("test_batch_07")
    # define parameters
    batch_size = 4
    drop_remainder = False
    num_parallel_workers = 2
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder,
                             'num_parallel_workers': num_parallel_workers}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
                        batch_size=batch_size)

    assert sum([1 for _ in data1]) == 3

    filename = "batch_07_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_08():
    """
    Test batch: num_parallel_workers=1, drop_remainder default
    """
    logger.info("test_batch_08")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1
    parameters = {"params": {'batch_size': batch_size,
                             'num_parallel_workers': num_parallel_workers}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)

    assert sum([1 for _ in data1]) == 2

    filename = "batch_08_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_09():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=False
    """
    logger.info("test_batch_09")
    # define parameters
    batch_size = 13
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
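
    # batch_size (13) exceeds the 12 rows; drop_remainder=False keeps the
    # single partial batch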
    assert sum([1 for _ in data1]) == 1

    filename = "batch_09_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_10():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=True
    """
    logger.info("test_batch_10")
    # define parameters
    batch_size = 99
    drop_remainder = True
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)
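
    # drop_remainder=True with batch_size > dataset size drops every row,
    # so no batches are produced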
    assert sum([1 for _ in data1]) == 0

    filename = "batch_10_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_11():
    """
    Test batch: batch_size=1 and dataset-size=1
    """
    logger.info("test_batch_11")
    # define parameters
    batch_size = 1
    parameters = {"params": {'batch_size': batch_size}}

    # apply dataset operations
    # Use schema file with 1 row
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema1Row.json"
    data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
    data1 = data1.batch(batch_size)

    assert sum([1 for _ in data1]) == 1

    filename = "batch_11_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_12():
    """
    Test batch: batch_size boolean value True, treated as valid value 1
    """
    logger.info("test_batch_12")
    # define parameters
    batch_size = True
    parameters = {"params": {'batch_size': batch_size}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size)
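
    # bool is a subclass of int in Python, so True passes validation and
    # behaves as batch_size=1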
    assert sum([1 for _ in data1]) == 12

    filename = "batch_12_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_exception_01():
    """
    Test batch exception: num_parallel_workers=0
    """
    logger.info("test_batch_exception_01")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_02():
    """
    Test batch exception: num_parallel_workers<0
    """
    logger.info("test_batch_exception_02")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_03():
    """
    Test batch exception: batch_size=0
    """
    logger.info("test_batch_exception_03")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_04():
    """
    Test batch exception: batch_size<0
    """
    logger.info("test_batch_exception_04")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=-1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_05():
    """
    Test batch exception: batch_size boolean value False, treated as invalid value 0
    """
    logger.info("test_batch_exception_05")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
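        # bool is a subclass of int, so False is treated as 0 and fails
        # batch_size validation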
        data1 = data1.batch(batch_size=False)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_07():
    """
    Test batch exception: drop_remainder wrong type
    """
    logger.info("test_batch_exception_07")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=0)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_08():
    """
    Test batch exception: num_parallel_workers wrong type
    """
    logger.info("test_batch_exception_08")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_09():
    """
    Test batch exception: Missing mandatory batch_size
    """
    logger.info("test_batch_exception_09")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(drop_remainder=True, num_parallel_workers=4)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_10():
    """
    Test batch exception: num_parallel_workers>>1
    """
    logger.info("test_batch_exception_10")

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=4, num_parallel_workers=8192)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_11():
    """
    Test batch exception: wrong input order, num_parallel_workers wrongly used as drop_remainder
    """
    logger.info("test_batch_exception_11")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
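        # the second positional parameter of batch() is drop_remainder, so
        # num_parallel_workers lands there and fails the bool type check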
        data1 = data1.batch(batch_size, num_parallel_workers)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_12():
    """
    Test batch exception: wrong input order, drop_remainder wrongly used as batch_size
    """
    logger.info("test_batch_exception_12")
    # define parameters
    batch_size = 1
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
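        # positionally, True fills batch_size (valid, acts as 1) while 1
        # fills drop_remainder, which must be a bool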
        data1 = data1.batch(drop_remainder, batch_size)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_13():
    """
    Test batch exception: invalid input parameter
    """
    logger.info("test_batch_exception_13")
    # define parameters
    batch_size = 4

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(batch_size, shard_id=1)
        sum([1 for _ in data1])

    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "shard_id" in str(e)
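

# Under pytest each test_* function above is collected automatically; the
# block below also allows running this file directly as a script.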
if __name__ == '__main__':
    test_batch_01()
    test_batch_02()
    test_batch_03()
    test_batch_04()
    test_batch_05()
    test_batch_06()
    test_batch_07()
    test_batch_08()
    test_batch_09()
    test_batch_10()
    test_batch_11()
    test_batch_12()
    test_batch_exception_01()
    test_batch_exception_02()
    test_batch_exception_03()
    test_batch_exception_04()
    test_batch_exception_05()
    test_batch_exception_07()
    test_batch_exception_08()
    test_batch_exception_09()
    test_batch_exception_10()
    test_batch_exception_11()
    test_batch_exception_12()
    test_batch_exception_13()
    logger.info('\n')