
test_batch.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from util import save_and_check

import mindspore.dataset as ds
from mindspore import log as logger

# Note: Number of rows in test.data dataset: 12
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
GENERATE_GOLDEN = False
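
# save_and_check (imported from util) presumably saves the pipeline output to
# the named .npz file and compares it against a stored golden copy; running
# with GENERATE_GOLDEN = True would regenerate the golden files instead.
# With 12 rows in the dataset, the expected batch count follows from simple
# arithmetic. A minimal sketch of that arithmetic (illustrative only; the
# helper below is hypothetical and not used by the tests):
#
#   import math
#
#   def expected_num_batches(num_rows, batch_size, drop_remainder=False):
#       # drop_remainder=True discards the final partial batch
#       if drop_remainder:
#           return num_rows // batch_size
#       return math.ceil(num_rows / batch_size)
#
#   expected_num_batches(12, 5, drop_remainder=True)   # -> 2
#   expected_num_batches(12, 5, drop_remainder=False)  # -> 3
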
def test_batch_01():
    """
    Test batch: batch_size>1, drop_remainder=True, no remainder exists
    """
    logger.info("test_batch_01")
    # define parameters
    batch_size = 2
    drop_remainder = True
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)

    filename = "batch_01_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_02():
    """
    Test batch: batch_size>1, drop_remainder=True, remainder exists
    """
    logger.info("test_batch_02")
    # define parameters
    batch_size = 5
    drop_remainder = True
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    filename = "batch_02_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_03():
    """
    Test batch: batch_size>1, drop_remainder=False, no remainder exists
    """
    logger.info("test_batch_03")
    # define parameters
    batch_size = 3
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size=batch_size, drop_remainder=drop_remainder)

    filename = "batch_03_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_04():
    """
    Test batch: batch_size>1, drop_remainder=False, remainder exists
    """
    logger.info("test_batch_04")
    # define parameters
    batch_size = 7
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder)

    filename = "batch_04_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
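
# Outside the golden-file flow, batch counts can be sanity-checked by
# iterating the pipeline directly, as the exception tests below do. A sketch
# for test_batch_04's parameters (12 rows, batch_size=7, drop_remainder=False):
#
#   data = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
#   data = data.batch(7, drop_remainder=False)
#   assert sum([1 for _ in data]) == 2  # one full batch of 7 plus a remainder of 5
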
def test_batch_05():
    """
    Test batch: batch_size=1 (minimum valid size), drop_remainder default
    """
    logger.info("test_batch_05")
    # define parameters
    batch_size = 1
    parameters = {"params": {'batch_size': batch_size}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size)

    filename = "batch_05_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_06():
    """
    Test batch: batch_size = number-of-rows-in-dataset, drop_remainder=False, reorder params
    """
    logger.info("test_batch_06")
    # define parameters
    batch_size = 12
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(drop_remainder=drop_remainder, batch_size=batch_size)

    filename = "batch_06_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_07():
    """
    Test batch: num_parallel_workers>1, drop_remainder=False, reorder params
    """
    logger.info("test_batch_07")
    # define parameters
    batch_size = 4
    drop_remainder = False
    num_parallel_workers = 2
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder,
                             'num_parallel_workers': num_parallel_workers}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(num_parallel_workers=num_parallel_workers, drop_remainder=drop_remainder,
                        batch_size=batch_size)

    filename = "batch_07_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_08():
    """
    Test batch: num_parallel_workers=1, drop_remainder default
    """
    logger.info("test_batch_08")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1
    parameters = {"params": {'batch_size': batch_size,
                             'num_parallel_workers': num_parallel_workers}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, num_parallel_workers=num_parallel_workers)

    filename = "batch_08_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_09():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=False
    """
    logger.info("test_batch_09")
    # define parameters
    batch_size = 13
    drop_remainder = False
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    filename = "batch_09_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_10():
    """
    Test batch: batch_size > number-of-rows-in-dataset, drop_remainder=True
    """
    logger.info("test_batch_10")
    # define parameters
    batch_size = 99
    drop_remainder = True
    parameters = {"params": {'batch_size': batch_size,
                             'drop_remainder': drop_remainder}}

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    data1 = data1.batch(batch_size, drop_remainder=drop_remainder)

    filename = "batch_10_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)


def test_batch_11():
    """
    Test batch: batch_size=1 and dataset-size=1
    """
    logger.info("test_batch_11")
    # define parameters
    batch_size = 1
    parameters = {"params": {'batch_size': batch_size}}

    # apply dataset operations
    # Use schema file with 1 row
    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema1Row.json"
    data1 = ds.TFRecordDataset(DATA_DIR, schema_file)
    data1 = data1.batch(batch_size)

    filename = "batch_11_result.npz"
    save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
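
# The exception tests below share one pattern: build the pipeline, force
# execution by iterating it, and assert that the offending parameter name
# appears in the error message. With pytest the same check could be written
# as follows (a sketch; the test name is hypothetical and pytest is assumed
# to be available in this test environment):
#
#   import pytest
#
#   def test_batch_rejects_zero_batch_size():
#       data = ds.TFRecordDataset(DATA_DIR)
#       with pytest.raises(Exception, match="batch_size"):
#           data = data.batch(batch_size=0)
#           sum([1 for _ in data])
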
def test_batch_exception_01():
    """
    Test batch exception: num_parallel_workers=0
    """
    logger.info("test_batch_exception_01")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=2, drop_remainder=True, num_parallel_workers=0)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_02():
    """
    Test batch exception: num_parallel_workers<0
    """
    logger.info("test_batch_exception_02")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=-1)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_03():
    """
    Test batch exception: batch_size=0
    """
    logger.info("test_batch_exception_03")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=0)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_04():
    """
    Test batch exception: batch_size<0
    """
    logger.info("test_batch_exception_04")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=-1)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_05():
    """
    Test batch exception: batch_size wrong type, boolean value False
    """
    logger.info("test_batch_exception_05")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=False)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)
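
# Note: bool is a subclass of int in Python (isinstance(True, int) is True),
# so batch_size=True can slip through a naive integer check as batch_size=1,
# which is a perfectly valid batch size. That is presumably why the
# True-valued variant below is skipped, while batch_size=False above is
# expected to fail (it would mean batch_size=0).
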
def skip_test_batch_exception_06():
    """
    Test batch exception: batch_size wrong type, boolean value True
    """
    logger.info("test_batch_exception_06")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=True)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_07():
    """
    Test batch exception: drop_remainder wrong type
    """
    logger.info("test_batch_exception_07")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=0)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)


def test_batch_exception_08():
    """
    Test batch exception: num_parallel_workers wrong type
    """
    logger.info("test_batch_exception_08")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(3, drop_remainder=True, num_parallel_workers=False)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_09():
    """
    Test batch exception: missing mandatory batch_size
    """
    logger.info("test_batch_exception_09")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(drop_remainder=True, num_parallel_workers=4)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_10():
    """
    Test batch exception: num_parallel_workers>>1
    """
    logger.info("test_batch_exception_10")
    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
    try:
        data1 = data1.batch(batch_size=4, num_parallel_workers=8192)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "num_parallel_workers" in str(e)


def test_batch_exception_11():
    """
    Test batch exception: wrong input order, num_parallel_workers wrongly used as drop_remainder
    """
    logger.info("test_batch_exception_11")
    # define parameters
    batch_size = 6
    num_parallel_workers = 1

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(batch_size, num_parallel_workers)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "drop_remainder" in str(e)
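
# In test_batch_exception_11 above, num_parallel_workers is passed as the
# second positional argument, where batch() expects drop_remainder, so the
# integer 1 fails the boolean type check. Passing it by keyword avoids the
# pitfall; a sketch, using the same call shapes as this file:
#
#   data1.batch(batch_size, num_parallel_workers=num_parallel_workers)  # correct
#   data1.batch(batch_size, num_parallel_workers)  # 1 lands in drop_remainder
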
def test_batch_exception_12():
    """
    Test batch exception: wrong input order, drop_remainder wrongly used as batch_size
    """
    logger.info("test_batch_exception_12")
    # define parameters
    batch_size = 1
    drop_remainder = True

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(drop_remainder, batch_size=batch_size)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "batch_size" in str(e)


def test_batch_exception_13():
    """
    Test batch exception: invalid input parameter
    """
    logger.info("test_batch_exception_13")
    # define parameters
    batch_size = 4

    # apply dataset operations
    data1 = ds.TFRecordDataset(DATA_DIR)
    try:
        data1 = data1.batch(batch_size, shard_id=1)
        sum([1 for _ in data1])
    except BaseException as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "shard_id" in str(e)


if __name__ == '__main__':
    test_batch_01()
    test_batch_02()
    test_batch_03()
    test_batch_04()
    test_batch_05()
    test_batch_06()
    test_batch_07()
    test_batch_08()
    test_batch_09()
    test_batch_10()
    test_batch_11()
    test_batch_exception_01()
    test_batch_exception_02()
    test_batch_exception_03()
    test_batch_exception_04()
    test_batch_exception_05()
    skip_test_batch_exception_06()
    test_batch_exception_07()
    test_batch_exception_08()
    test_batch_exception_09()
    test_batch_exception_10()
    test_batch_exception_11()
    test_batch_exception_12()
    test_batch_exception_13()
    logger.info('\n')