You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cache_map.py 54 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing cache operator with mappable datasets
  17. """
  18. import os
  19. import pytest
  20. import numpy as np
  21. import mindspore.dataset as ds
  22. import mindspore.dataset.vision.c_transforms as c_vision
  23. from mindspore import log as logger
  24. from util import save_and_check_md5
# Paths to the small fixture datasets used by the testcases below.
DATA_DIR = "../data/dataset/testImageNetData/train/"
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
# Directory that deliberately contains no image files (negative test input).
NO_IMAGE_DIR = "../data/dataset/testRandomData/"
MNIST_DATA_DIR = "../data/dataset/testMnistData/"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
VOC_DATA_DIR = "../data/dataset/testVOC2012/"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"

# When True, save_and_check_md5 regenerates the golden .npz result files
# instead of comparing against them.
GENERATE_GOLDEN = False
  37. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  38. def test_cache_map_basic1():
  39. """
  40. Test mappable leaf with cache op right over the leaf
  41. Repeat
  42. |
  43. Map(decode)
  44. |
  45. Cache
  46. |
  47. ImageFolder
  48. """
  49. logger.info("Test cache map basic 1")
  50. if "SESSION_ID" in os.environ:
  51. session_id = int(os.environ['SESSION_ID'])
  52. else:
  53. session_id = 1
  54. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  55. # This DATA_DIR only has 2 images in it
  56. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  57. decode_op = c_vision.Decode()
  58. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  59. ds1 = ds1.repeat(4)
  60. filename = "cache_map_01_result.npz"
  61. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  62. logger.info("test_cache_map_basic1 Ended.\n")
  63. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  64. def test_cache_map_basic2():
  65. """
  66. Test mappable leaf with the cache op later in the tree above the map(decode)
  67. Repeat
  68. |
  69. Cache
  70. |
  71. Map(decode)
  72. |
  73. ImageFolder
  74. """
  75. logger.info("Test cache map basic 2")
  76. if "SESSION_ID" in os.environ:
  77. session_id = int(os.environ['SESSION_ID'])
  78. else:
  79. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  80. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  81. # This DATA_DIR only has 2 images in it
  82. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  83. decode_op = c_vision.Decode()
  84. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  85. ds1 = ds1.repeat(4)
  86. filename = "cache_map_02_result.npz"
  87. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  88. logger.info("test_cache_map_basic2 Ended.\n")
  89. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  90. def test_cache_map_basic3():
  91. """
  92. Test a repeat under mappable cache
  93. Cache
  94. |
  95. Map(decode)
  96. |
  97. Repeat
  98. |
  99. ImageFolder
  100. """
  101. logger.info("Test cache basic 3")
  102. if "SESSION_ID" in os.environ:
  103. session_id = int(os.environ['SESSION_ID'])
  104. else:
  105. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  106. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  107. # This DATA_DIR only has 2 images in it
  108. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  109. decode_op = c_vision.Decode()
  110. ds1 = ds1.repeat(4)
  111. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  112. logger.info("ds1.dataset_size is ", ds1.get_dataset_size())
  113. num_iter = 0
  114. for _ in ds1.create_dict_iterator(num_epochs=1):
  115. logger.info("get data from dataset")
  116. num_iter += 1
  117. logger.info("Number of data in ds1: {} ".format(num_iter))
  118. assert num_iter == 8
  119. logger.info('test_cache_basic3 Ended.\n')
  120. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  121. def test_cache_map_basic4():
  122. """
  123. Test different rows result in core dump
  124. """
  125. logger.info("Test cache basic 4")
  126. if "SESSION_ID" in os.environ:
  127. session_id = int(os.environ['SESSION_ID'])
  128. else:
  129. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  130. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  131. # This DATA_DIR only has 2 images in it
  132. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  133. decode_op = c_vision.Decode()
  134. ds1 = ds1.repeat(4)
  135. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  136. logger.info("ds1.dataset_size is ", ds1.get_dataset_size())
  137. shape = ds1.output_shapes()
  138. logger.info(shape)
  139. num_iter = 0
  140. for _ in ds1.create_dict_iterator(num_epochs=1):
  141. logger.info("get data from dataset")
  142. num_iter += 1
  143. logger.info("Number of data in ds1: {} ".format(num_iter))
  144. assert num_iter == 8
  145. logger.info('test_cache_basic4 Ended.\n')
  146. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  147. def test_cache_map_basic5():
  148. """
  149. Test Map with non-deterministic TensorOps above cache
  150. repeat
  151. |
  152. Map(decode, randomCrop)
  153. |
  154. Cache
  155. |
  156. ImageFolder
  157. """
  158. logger.info("Test cache failure 5")
  159. if "SESSION_ID" in os.environ:
  160. session_id = int(os.environ['SESSION_ID'])
  161. else:
  162. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  163. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  164. # This DATA_DIR only has 2 images in it
  165. data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  166. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  167. decode_op = c_vision.Decode()
  168. data = data.map(input_columns=["image"], operations=decode_op)
  169. data = data.map(input_columns=["image"], operations=random_crop_op)
  170. data = data.repeat(4)
  171. num_iter = 0
  172. for _ in data.create_dict_iterator():
  173. num_iter += 1
  174. logger.info("Number of data in ds1: {} ".format(num_iter))
  175. assert num_iter == 8
  176. logger.info('test_cache_failure5 Ended.\n')
  177. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  178. def test_cache_map_basic6():
  179. """
  180. Test cache as root node
  181. cache
  182. |
  183. ImageFolder
  184. """
  185. logger.info("Test cache basic 6")
  186. if "SESSION_ID" in os.environ:
  187. session_id = int(os.environ['SESSION_ID'])
  188. else:
  189. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  190. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  191. # This DATA_DIR only has 2 images in it
  192. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  193. num_iter = 0
  194. for _ in ds1.create_dict_iterator(num_epochs=1):
  195. logger.info("get data from dataset")
  196. num_iter += 1
  197. logger.info("Number of data in ds1: {} ".format(num_iter))
  198. assert num_iter == 2
  199. logger.info('test_cache_basic6 Ended.\n')
  200. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  201. def test_cache_map_failure1():
  202. """
  203. Test nested cache (failure)
  204. Repeat
  205. |
  206. Cache
  207. |
  208. Map(decode)
  209. |
  210. Cache
  211. |
  212. ImageFolder
  213. """
  214. logger.info("Test cache failure 1")
  215. if "SESSION_ID" in os.environ:
  216. session_id = int(os.environ['SESSION_ID'])
  217. else:
  218. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  219. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  220. # This DATA_DIR only has 2 images in it
  221. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  222. decode_op = c_vision.Decode()
  223. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  224. ds1 = ds1.repeat(4)
  225. with pytest.raises(RuntimeError) as e:
  226. num_iter = 0
  227. for _ in ds1.create_dict_iterator(num_epochs=1):
  228. num_iter += 1
  229. assert "Nested cache operations is not supported!" in str(e.value)
  230. assert num_iter == 0
  231. logger.info('test_cache_failure1 Ended.\n')
  232. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  233. def test_cache_map_failure2():
  234. """
  235. Test zip under cache (failure)
  236. repeat
  237. |
  238. Cache
  239. |
  240. Map(decode)
  241. |
  242. Zip
  243. | |
  244. ImageFolder ImageFolder
  245. """
  246. logger.info("Test cache failure 2")
  247. if "SESSION_ID" in os.environ:
  248. session_id = int(os.environ['SESSION_ID'])
  249. else:
  250. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  251. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  252. # This DATA_DIR only has 2 images in it
  253. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  254. ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  255. dsz = ds.zip((ds1, ds2))
  256. decode_op = c_vision.Decode()
  257. dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  258. dsz = dsz.repeat(4)
  259. with pytest.raises(RuntimeError) as e:
  260. num_iter = 0
  261. for _ in dsz.create_dict_iterator():
  262. num_iter += 1
  263. assert "ZipOp is currently not supported as a descendant operator under a cache" in str(e.value)
  264. assert num_iter == 0
  265. logger.info('test_cache_failure2 Ended.\n')
  266. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  267. def test_cache_map_failure3():
  268. """
  269. Test batch under cache (failure)
  270. repeat
  271. |
  272. Cache
  273. |
  274. Map(resize)
  275. |
  276. Batch
  277. |
  278. ImageFolder
  279. """
  280. logger.info("Test cache failure 3")
  281. if "SESSION_ID" in os.environ:
  282. session_id = int(os.environ['SESSION_ID'])
  283. else:
  284. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  285. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  286. # This DATA_DIR only has 2 images in it
  287. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  288. ds1 = ds1.batch(2)
  289. resize_op = c_vision.Resize((224, 224))
  290. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  291. ds1 = ds1.repeat(4)
  292. with pytest.raises(RuntimeError) as e:
  293. num_iter = 0
  294. for _ in ds1.create_dict_iterator():
  295. num_iter += 1
  296. assert "Unexpected error. Expect positive row id: -1" in str(e.value)
  297. assert num_iter == 0
  298. logger.info('test_cache_failure3 Ended.\n')
  299. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  300. def test_cache_map_failure4():
  301. """
  302. Test filter under cache (failure)
  303. repeat
  304. |
  305. Cache
  306. |
  307. Map(decode)
  308. |
  309. Filter
  310. |
  311. ImageFolder
  312. """
  313. logger.info("Test cache failure 4")
  314. if "SESSION_ID" in os.environ:
  315. session_id = int(os.environ['SESSION_ID'])
  316. else:
  317. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  318. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  319. # This DATA_DIR only has 2 images in it
  320. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  321. ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])
  322. decode_op = c_vision.Decode()
  323. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  324. ds1 = ds1.repeat(4)
  325. with pytest.raises(RuntimeError) as e:
  326. num_iter = 0
  327. for _ in ds1.create_dict_iterator():
  328. num_iter += 1
  329. assert "FilterOp is currently not supported as a descendant operator under a cache" in str(e.value)
  330. assert num_iter == 0
  331. logger.info('test_cache_failure4 Ended.\n')
  332. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  333. def test_cache_map_failure5():
  334. """
  335. Test Map with non-deterministic TensorOps under cache (failure)
  336. repeat
  337. |
  338. Cache
  339. |
  340. Map(decode, randomCrop)
  341. |
  342. ImageFolder
  343. """
  344. logger.info("Test cache failure 5")
  345. if "SESSION_ID" in os.environ:
  346. session_id = int(os.environ['SESSION_ID'])
  347. else:
  348. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  349. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  350. # This DATA_DIR only has 2 images in it
  351. data = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  352. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  353. decode_op = c_vision.Decode()
  354. data = data.map(input_columns=["image"], operations=decode_op)
  355. data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
  356. data = data.repeat(4)
  357. with pytest.raises(RuntimeError) as e:
  358. num_iter = 0
  359. for _ in data.create_dict_iterator():
  360. num_iter += 1
  361. assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache" in str(e.value)
  362. assert num_iter == 0
  363. logger.info('test_cache_failure5 Ended.\n')
  364. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  365. def test_cache_map_failure6():
  366. """
  367. Test no-cache-supporting MindRecord leaf with Map under cache (failure)
  368. repeat
  369. |
  370. Cache
  371. |
  372. Map(resize)
  373. |
  374. MindRecord
  375. """
  376. logger.info("Test cache failure 6")
  377. if "SESSION_ID" in os.environ:
  378. session_id = int(os.environ['SESSION_ID'])
  379. else:
  380. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  381. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  382. columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
  383. num_readers = 1
  384. # The dataset has 5 records
  385. data = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, num_readers)
  386. resize_op = c_vision.Resize((224, 224))
  387. data = data.map(input_columns=["img_data"], operations=resize_op, cache=some_cache)
  388. data = data.repeat(4)
  389. with pytest.raises(RuntimeError) as e:
  390. num_iter = 0
  391. for _ in data.create_dict_iterator():
  392. num_iter += 1
  393. assert "There is currently no support for MindRecordOp under cache" in str(e.value)
  394. assert num_iter == 0
  395. logger.info('test_cache_failure6 Ended.\n')
  396. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  397. def test_cache_map_failure7():
  398. """
  399. Test no-cache-supporting Generator leaf with Map under cache (failure)
  400. repeat
  401. |
  402. Cache
  403. |
  404. Map(lambda x: x)
  405. |
  406. Generator
  407. """
  408. def generator_1d():
  409. for i in range(64):
  410. yield (np.array(i),)
  411. logger.info("Test cache failure 7")
  412. if "SESSION_ID" in os.environ:
  413. session_id = int(os.environ['SESSION_ID'])
  414. else:
  415. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  416. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  417. data = ds.GeneratorDataset(generator_1d, ["data"])
  418. data = data.map((lambda x: x), ["data"], cache=some_cache)
  419. data = data.repeat(4)
  420. with pytest.raises(RuntimeError) as e:
  421. num_iter = 0
  422. for _ in data.create_dict_iterator():
  423. num_iter += 1
  424. assert "There is currently no support for GeneratorOp under cache" in str(e.value)
  425. assert num_iter == 0
  426. logger.info('test_cache_failure7 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_parameter_check():
    """
    Test illegal parameters for DatasetCache
    """
    logger.info("Test cache map parameter check")

    # session_id must be a non-negative int
    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=-1, size=0, spilling=True)
    assert "Input is not within the required interval" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id="1", size=0, spilling=True)
    assert "Argument session_id with value 1 is not of type (<class 'int'>,)" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=None, size=0, spilling=True)
    assert "Argument session_id with value None is not of type (<class 'int'>,)" in str(info.value)

    # size must be a non-negative int
    with pytest.raises(ValueError) as info:
        ds.DatasetCache(session_id=1, size=-1, spilling=True)
    assert "Input is not within the required interval" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size="1", spilling=True)
    assert "Argument size with value 1 is not of type (<class 'int'>,)" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=None, spilling=True)
    assert "Argument size with value None is not of type (<class 'int'>,)" in str(info.value)

    # spilling must be a bool
    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, spilling="illegal")
    assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)

    # hostname: the cache client must be on the same host as the server
    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
    assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2")
    assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)

    # port: strings are rejected by the bound constructor, 0 and 65536 by range checks
    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
    assert "incompatible constructor arguments" in str(info.value)

    with pytest.raises(TypeError) as info:
        ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
    assert "incompatible constructor arguments" in str(info.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
    assert "Unexpected error. port must be positive" in str(err.value)

    with pytest.raises(RuntimeError) as err:
        ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
    assert "Unexpected error. illegal port number" in str(err.value)

    # the cache argument of a dataset op must be a DatasetCache instance
    with pytest.raises(TypeError) as err:
        ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)
    assert "Argument cache with value True is not of type" in str(err.value)

    logger.info("test_cache_map_parameter_check Ended.\n")
  476. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  477. def test_cache_map_running_twice1():
  478. """
  479. Executing the same pipeline for twice (from python), with cache injected after map
  480. Repeat
  481. |
  482. Cache
  483. |
  484. Map(decode)
  485. |
  486. ImageFolder
  487. """
  488. logger.info("Test cache map running twice 1")
  489. if "SESSION_ID" in os.environ:
  490. session_id = int(os.environ['SESSION_ID'])
  491. else:
  492. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  493. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  494. # This DATA_DIR only has 2 images in it
  495. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  496. decode_op = c_vision.Decode()
  497. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  498. ds1 = ds1.repeat(4)
  499. num_iter = 0
  500. for _ in ds1.create_dict_iterator():
  501. num_iter += 1
  502. logger.info("Number of data in ds1: {} ".format(num_iter))
  503. assert num_iter == 8
  504. num_iter = 0
  505. for _ in ds1.create_dict_iterator():
  506. num_iter += 1
  507. logger.info("Number of data in ds1: {} ".format(num_iter))
  508. assert num_iter == 8
  509. logger.info("test_cache_map_running_twice1 Ended.\n")
  510. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  511. def test_cache_map_running_twice2():
  512. """
  513. Executing the same pipeline for twice (from shell), with cache injected after leaf
  514. Repeat
  515. |
  516. Map(decode)
  517. |
  518. Cache
  519. |
  520. ImageFolder
  521. """
  522. logger.info("Test cache map running twice 2")
  523. if "SESSION_ID" in os.environ:
  524. session_id = int(os.environ['SESSION_ID'])
  525. else:
  526. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  527. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  528. # This DATA_DIR only has 2 images in it
  529. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  530. decode_op = c_vision.Decode()
  531. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  532. ds1 = ds1.repeat(4)
  533. num_iter = 0
  534. for _ in ds1.create_dict_iterator():
  535. num_iter += 1
  536. logger.info("Number of data in ds1: {} ".format(num_iter))
  537. assert num_iter == 8
  538. logger.info("test_cache_map_running_twice2 Ended.\n")
  539. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  540. def test_cache_map_extra_small_size1():
  541. """
  542. Test running pipeline with cache of extra small size and spilling true
  543. Repeat
  544. |
  545. Map(decode)
  546. |
  547. Cache
  548. |
  549. ImageFolder
  550. """
  551. logger.info("Test cache map extra small size 1")
  552. if "SESSION_ID" in os.environ:
  553. session_id = int(os.environ['SESSION_ID'])
  554. else:
  555. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  556. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)
  557. # This DATA_DIR only has 2 images in it
  558. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  559. decode_op = c_vision.Decode()
  560. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  561. ds1 = ds1.repeat(4)
  562. num_iter = 0
  563. for _ in ds1.create_dict_iterator():
  564. num_iter += 1
  565. logger.info("Number of data in ds1: {} ".format(num_iter))
  566. assert num_iter == 8
  567. logger.info("test_cache_map_extra_small_size1 Ended.\n")
  568. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  569. def test_cache_map_extra_small_size2():
  570. """
  571. Test running pipeline with cache of extra small size and spilling false
  572. Repeat
  573. |
  574. Cache
  575. |
  576. Map(decode)
  577. |
  578. ImageFolder
  579. """
  580. logger.info("Test cache map extra small size 2")
  581. if "SESSION_ID" in os.environ:
  582. session_id = int(os.environ['SESSION_ID'])
  583. else:
  584. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  585. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
  586. # This DATA_DIR only has 2 images in it
  587. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  588. decode_op = c_vision.Decode()
  589. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  590. ds1 = ds1.repeat(4)
  591. num_iter = 0
  592. for _ in ds1.create_dict_iterator():
  593. num_iter += 1
  594. logger.info("Number of data in ds1: {} ".format(num_iter))
  595. assert num_iter == 8
  596. logger.info("test_cache_map_extra_small_size2 Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_no_image():
    """
    Test cache with no dataset existing in the path

       Repeat
         |
     Map(decode)
         |
       Cache
         |
     ImageFolder
    """
    logger.info("Test cache map no image")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)

    # NO_IMAGE_DIR contains no image files (the original "2 images" comment was
    # a copy-paste slip), so the test expects a RuntimeError during iteration.
    ds1 = ds.ImageFolderDataset(dataset_dir=NO_IMAGE_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError):
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1

    # No rows should have been produced before the failure.
    assert num_iter == 0
    logger.info("test_cache_map_no_image Ended.\n")
  626. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  627. def test_cache_map_parallel_pipeline1(shard):
  628. """
  629. Test running two parallel pipelines (sharing cache) with cache injected after leaf op
  630. Repeat
  631. |
  632. Map(decode)
  633. |
  634. Cache
  635. |
  636. ImageFolder
  637. """
  638. logger.info("Test cache map parallel pipeline 1")
  639. if "SESSION_ID" in os.environ:
  640. session_id = int(os.environ['SESSION_ID'])
  641. else:
  642. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  643. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  644. # This DATA_DIR only has 2 images in it
  645. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache)
  646. decode_op = c_vision.Decode()
  647. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  648. ds1 = ds1.repeat(4)
  649. num_iter = 0
  650. for _ in ds1.create_dict_iterator():
  651. num_iter += 1
  652. logger.info("Number of data in ds1: {} ".format(num_iter))
  653. assert num_iter == 4
  654. logger.info("test_cache_map_parallel_pipeline1 Ended.\n")
  655. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  656. def test_cache_map_parallel_pipeline2(shard):
  657. """
  658. Test running two parallel pipelines (sharing cache) with cache injected after map op
  659. Repeat
  660. |
  661. Cache
  662. |
  663. Map(decode)
  664. |
  665. ImageFolder
  666. """
  667. logger.info("Test cache map parallel pipeline 2")
  668. if "SESSION_ID" in os.environ:
  669. session_id = int(os.environ['SESSION_ID'])
  670. else:
  671. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  672. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  673. # This DATA_DIR only has 2 images in it
  674. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
  675. decode_op = c_vision.Decode()
  676. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  677. ds1 = ds1.repeat(4)
  678. num_iter = 0
  679. for _ in ds1.create_dict_iterator():
  680. num_iter += 1
  681. logger.info("Number of data in ds1: {} ".format(num_iter))
  682. assert num_iter == 4
  683. logger.info("test_cache_map_parallel_pipeline2 Ended.\n")
  684. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  685. def test_cache_map_parallel_workers():
  686. """
  687. Test cache with num_parallel_workers > 1 set for map op and leaf op
  688. Repeat
  689. |
  690. cache
  691. |
  692. Map(decode)
  693. |
  694. ImageFolder
  695. """
  696. logger.info("Test cache map parallel workers")
  697. if "SESSION_ID" in os.environ:
  698. session_id = int(os.environ['SESSION_ID'])
  699. else:
  700. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  701. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  702. # This DATA_DIR only has 2 images in it
  703. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
  704. decode_op = c_vision.Decode()
  705. ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
  706. ds1 = ds1.repeat(4)
  707. num_iter = 0
  708. for _ in ds1.create_dict_iterator():
  709. num_iter += 1
  710. logger.info("Number of data in ds1: {} ".format(num_iter))
  711. assert num_iter == 8
  712. logger.info("test_cache_map_parallel_workers Ended.\n")
  713. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  714. def test_cache_map_server_workers_1():
  715. """
  716. start cache server with --workers 1 and then test cache function
  717. Repeat
  718. |
  719. cache
  720. |
  721. Map(decode)
  722. |
  723. ImageFolder
  724. """
  725. logger.info("Test cache map server workers 1")
  726. if "SESSION_ID" in os.environ:
  727. session_id = int(os.environ['SESSION_ID'])
  728. else:
  729. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  730. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  731. # This DATA_DIR only has 2 images in it
  732. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  733. decode_op = c_vision.Decode()
  734. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  735. ds1 = ds1.repeat(4)
  736. num_iter = 0
  737. for _ in ds1.create_dict_iterator():
  738. num_iter += 1
  739. logger.info("Number of data in ds1: {} ".format(num_iter))
  740. assert num_iter == 8
  741. logger.info("test_cache_map_server_workers_1 Ended.\n")
  742. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  743. def test_cache_map_server_workers_100():
  744. """
  745. start cache server with --workers 100 and then test cache function
  746. Repeat
  747. |
  748. Map(decode)
  749. |
  750. cache
  751. |
  752. ImageFolder
  753. """
  754. logger.info("Test cache map server workers 100")
  755. if "SESSION_ID" in os.environ:
  756. session_id = int(os.environ['SESSION_ID'])
  757. else:
  758. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  759. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  760. # This DATA_DIR only has 2 images in it
  761. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  762. decode_op = c_vision.Decode()
  763. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  764. ds1 = ds1.repeat(4)
  765. num_iter = 0
  766. for _ in ds1.create_dict_iterator():
  767. num_iter += 1
  768. logger.info("Number of data in ds1: {} ".format(num_iter))
  769. assert num_iter == 8
  770. logger.info("test_cache_map_server_workers_100 Ended.\n")
  771. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  772. def test_cache_map_num_connections_1():
  773. """
  774. Test setting num_connections=1 in DatasetCache
  775. Repeat
  776. |
  777. cache
  778. |
  779. Map(decode)
  780. |
  781. ImageFolder
  782. """
  783. logger.info("Test cache map num_connections 1")
  784. if "SESSION_ID" in os.environ:
  785. session_id = int(os.environ['SESSION_ID'])
  786. else:
  787. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  788. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
  789. # This DATA_DIR only has 2 images in it
  790. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  791. decode_op = c_vision.Decode()
  792. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  793. ds1 = ds1.repeat(4)
  794. num_iter = 0
  795. for _ in ds1.create_dict_iterator():
  796. num_iter += 1
  797. logger.info("Number of data in ds1: {} ".format(num_iter))
  798. assert num_iter == 8
  799. logger.info("test_cache_map_num_connections_1 Ended.\n")
  800. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  801. def test_cache_map_num_connections_100():
  802. """
  803. Test setting num_connections=100 in DatasetCache
  804. Repeat
  805. |
  806. Map(decode)
  807. |
  808. cache
  809. |
  810. ImageFolder
  811. """
  812. logger.info("Test cache map num_connections 100")
  813. if "SESSION_ID" in os.environ:
  814. session_id = int(os.environ['SESSION_ID'])
  815. else:
  816. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  817. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)
  818. # This DATA_DIR only has 2 images in it
  819. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  820. decode_op = c_vision.Decode()
  821. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  822. ds1 = ds1.repeat(4)
  823. num_iter = 0
  824. for _ in ds1.create_dict_iterator():
  825. num_iter += 1
  826. logger.info("Number of data in ds1: {} ".format(num_iter))
  827. assert num_iter == 8
  828. logger.info("test_cache_map_num_connections_100 Ended.\n")
  829. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  830. def test_cache_map_prefetch_size_1():
  831. """
  832. Test setting prefetch_size=1 in DatasetCache
  833. Repeat
  834. |
  835. cache
  836. |
  837. Map(decode)
  838. |
  839. ImageFolder
  840. """
  841. logger.info("Test cache map prefetch_size 1")
  842. if "SESSION_ID" in os.environ:
  843. session_id = int(os.environ['SESSION_ID'])
  844. else:
  845. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  846. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
  847. # This DATA_DIR only has 2 images in it
  848. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  849. decode_op = c_vision.Decode()
  850. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  851. ds1 = ds1.repeat(4)
  852. num_iter = 0
  853. for _ in ds1.create_dict_iterator():
  854. num_iter += 1
  855. logger.info("Number of data in ds1: {} ".format(num_iter))
  856. assert num_iter == 8
  857. logger.info("test_cache_map_prefetch_size_1 Ended.\n")
  858. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  859. def test_cache_map_prefetch_size_100():
  860. """
  861. Test setting prefetch_size=100 in DatasetCache
  862. Repeat
  863. |
  864. Map(decode)
  865. |
  866. cache
  867. |
  868. ImageFolder
  869. """
  870. logger.info("Test cache map prefetch_size 100")
  871. if "SESSION_ID" in os.environ:
  872. session_id = int(os.environ['SESSION_ID'])
  873. else:
  874. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  875. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)
  876. # This DATA_DIR only has 2 images in it
  877. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  878. decode_op = c_vision.Decode()
  879. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  880. ds1 = ds1.repeat(4)
  881. num_iter = 0
  882. for _ in ds1.create_dict_iterator():
  883. num_iter += 1
  884. logger.info("Number of data in ds1: {} ".format(num_iter))
  885. assert num_iter == 8
  886. logger.info("test_cache_map_prefetch_size_100 Ended.\n")
  887. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  888. def test_cache_map_to_device():
  889. """
  890. Test cache with to_device
  891. DeviceQueue
  892. |
  893. EpochCtrl
  894. |
  895. Repeat
  896. |
  897. Map(decode)
  898. |
  899. cache
  900. |
  901. ImageFolder
  902. """
  903. logger.info("Test cache map to_device")
  904. if "SESSION_ID" in os.environ:
  905. session_id = int(os.environ['SESSION_ID'])
  906. else:
  907. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  908. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  909. # This DATA_DIR only has 2 images in it
  910. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  911. decode_op = c_vision.Decode()
  912. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  913. ds1 = ds1.repeat(4)
  914. ds1 = ds1.to_device()
  915. ds1.send()
  916. logger.info("test_cache_map_to_device Ended.\n")
  917. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  918. def test_cache_map_epoch_ctrl1():
  919. """
  920. Test using two-loops method to run several epochs
  921. Map(decode)
  922. |
  923. cache
  924. |
  925. ImageFolder
  926. """
  927. logger.info("Test cache map epoch ctrl1")
  928. if "SESSION_ID" in os.environ:
  929. session_id = int(os.environ['SESSION_ID'])
  930. else:
  931. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  932. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  933. # This DATA_DIR only has 2 images in it
  934. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  935. decode_op = c_vision.Decode()
  936. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  937. num_epoch = 5
  938. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  939. epoch_count = 0
  940. for _ in range(num_epoch):
  941. row_count = 0
  942. for _ in iter1:
  943. row_count += 1
  944. logger.info("Number of data in ds1: {} ".format(row_count))
  945. assert row_count == 2
  946. epoch_count += 1
  947. assert epoch_count == num_epoch
  948. logger.info("test_cache_map_epoch_ctrl1 Ended.\n")
  949. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  950. def test_cache_map_epoch_ctrl2():
  951. """
  952. Test using two-loops method with infinite epochs
  953. cache
  954. |
  955. Map(decode)
  956. |
  957. ImageFolder
  958. """
  959. logger.info("Test cache map epoch ctrl2")
  960. if "SESSION_ID" in os.environ:
  961. session_id = int(os.environ['SESSION_ID'])
  962. else:
  963. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  964. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  965. # This DATA_DIR only has 2 images in it
  966. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  967. decode_op = c_vision.Decode()
  968. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  969. num_epoch = 5
  970. # iter1 will always assume there is a next epoch and never shutdown
  971. iter1 = ds1.create_dict_iterator()
  972. epoch_count = 0
  973. for _ in range(num_epoch):
  974. row_count = 0
  975. for _ in iter1:
  976. row_count += 1
  977. logger.info("Number of data in ds1: {} ".format(row_count))
  978. assert row_count == 2
  979. epoch_count += 1
  980. assert epoch_count == num_epoch
  981. # manually stop the iterator
  982. iter1.stop()
  983. logger.info("test_cache_map_epoch_ctrl2 Ended.\n")
  984. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  985. def test_cache_map_epoch_ctrl3():
  986. """
  987. Test using two-loops method with infinite epochs over repeat
  988. repeat
  989. |
  990. Map(decode)
  991. |
  992. cache
  993. |
  994. ImageFolder
  995. """
  996. logger.info("Test cache map epoch ctrl3")
  997. if "SESSION_ID" in os.environ:
  998. session_id = int(os.environ['SESSION_ID'])
  999. else:
  1000. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1001. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1002. # This DATA_DIR only has 2 images in it
  1003. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  1004. decode_op = c_vision.Decode()
  1005. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  1006. ds1 = ds1.repeat(2)
  1007. num_epoch = 5
  1008. # iter1 will always assume there is a next epoch and never shutdown
  1009. iter1 = ds1.create_dict_iterator()
  1010. epoch_count = 0
  1011. for _ in range(num_epoch):
  1012. row_count = 0
  1013. for _ in iter1:
  1014. row_count += 1
  1015. logger.info("Number of data in ds1: {} ".format(row_count))
  1016. assert row_count == 4
  1017. epoch_count += 1
  1018. assert epoch_count == num_epoch
  1019. # reply on garbage collector to destroy iter1
  1020. logger.info("test_cache_map_epoch_ctrl3 Ended.\n")
  1021. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1022. def test_cache_map_coco1():
  1023. """
  1024. Test mappable coco leaf with cache op right over the leaf
  1025. cache
  1026. |
  1027. Coco
  1028. """
  1029. logger.info("Test cache map coco1")
  1030. if "SESSION_ID" in os.environ:
  1031. session_id = int(os.environ['SESSION_ID'])
  1032. else:
  1033. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1034. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1035. # This dataset has 6 records
  1036. ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
  1037. cache=some_cache)
  1038. num_epoch = 4
  1039. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1040. epoch_count = 0
  1041. for _ in range(num_epoch):
  1042. assert sum([1 for _ in iter1]) == 6
  1043. epoch_count += 1
  1044. assert epoch_count == num_epoch
  1045. logger.info("test_cache_map_coco1 Ended.\n")
  1046. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1047. def test_cache_map_coco2():
  1048. """
  1049. Test mappable coco leaf with the cache op later in the tree above the map(resize)
  1050. cache
  1051. |
  1052. Map(resize)
  1053. |
  1054. Coco
  1055. """
  1056. logger.info("Test cache map coco2")
  1057. if "SESSION_ID" in os.environ:
  1058. session_id = int(os.environ['SESSION_ID'])
  1059. else:
  1060. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1061. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1062. # This dataset has 6 records
  1063. ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
  1064. resize_op = c_vision.Resize((224, 224))
  1065. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1066. num_epoch = 4
  1067. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1068. epoch_count = 0
  1069. for _ in range(num_epoch):
  1070. assert sum([1 for _ in iter1]) == 6
  1071. epoch_count += 1
  1072. assert epoch_count == num_epoch
  1073. logger.info("test_cache_map_coco2 Ended.\n")
  1074. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1075. def test_cache_map_mnist1():
  1076. """
  1077. Test mappable mnist leaf with cache op right over the leaf
  1078. cache
  1079. |
  1080. Mnist
  1081. """
  1082. logger.info("Test cache map mnist1")
  1083. if "SESSION_ID" in os.environ:
  1084. session_id = int(os.environ['SESSION_ID'])
  1085. else:
  1086. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1087. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1088. ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)
  1089. num_epoch = 4
  1090. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1091. epoch_count = 0
  1092. for _ in range(num_epoch):
  1093. assert sum([1 for _ in iter1]) == 10
  1094. epoch_count += 1
  1095. assert epoch_count == num_epoch
  1096. logger.info("test_cache_map_mnist1 Ended.\n")
  1097. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1098. def test_cache_map_mnist2():
  1099. """
  1100. Test mappable mnist leaf with the cache op later in the tree above the map(resize)
  1101. cache
  1102. |
  1103. Map(resize)
  1104. |
  1105. Mnist
  1106. """
  1107. logger.info("Test cache map mnist2")
  1108. if "SESSION_ID" in os.environ:
  1109. session_id = int(os.environ['SESSION_ID'])
  1110. else:
  1111. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1112. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1113. ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
  1114. resize_op = c_vision.Resize((224, 224))
  1115. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1116. num_epoch = 4
  1117. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1118. epoch_count = 0
  1119. for _ in range(num_epoch):
  1120. assert sum([1 for _ in iter1]) == 10
  1121. epoch_count += 1
  1122. assert epoch_count == num_epoch
  1123. logger.info("test_cache_map_mnist2 Ended.\n")
  1124. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1125. def test_cache_map_celeba1():
  1126. """
  1127. Test mappable celeba leaf with cache op right over the leaf
  1128. cache
  1129. |
  1130. CelebA
  1131. """
  1132. logger.info("Test cache map celeba1")
  1133. if "SESSION_ID" in os.environ:
  1134. session_id = int(os.environ['SESSION_ID'])
  1135. else:
  1136. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1137. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1138. # This dataset has 4 records
  1139. ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)
  1140. num_epoch = 4
  1141. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1142. epoch_count = 0
  1143. for _ in range(num_epoch):
  1144. assert sum([1 for _ in iter1]) == 4
  1145. epoch_count += 1
  1146. assert epoch_count == num_epoch
  1147. logger.info("test_cache_map_celeba1 Ended.\n")
  1148. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1149. def test_cache_map_celeba2():
  1150. """
  1151. Test mappable celeba leaf with the cache op later in the tree above the map(resize)
  1152. cache
  1153. |
  1154. Map(resize)
  1155. |
  1156. CelebA
  1157. """
  1158. logger.info("Test cache map celeba2")
  1159. if "SESSION_ID" in os.environ:
  1160. session_id = int(os.environ['SESSION_ID'])
  1161. else:
  1162. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1163. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1164. # This dataset has 4 records
  1165. ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
  1166. resize_op = c_vision.Resize((224, 224))
  1167. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1168. num_epoch = 4
  1169. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1170. epoch_count = 0
  1171. for _ in range(num_epoch):
  1172. assert sum([1 for _ in iter1]) == 4
  1173. epoch_count += 1
  1174. assert epoch_count == num_epoch
  1175. logger.info("test_cache_map_celeba2 Ended.\n")
  1176. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1177. def test_cache_map_manifest1():
  1178. """
  1179. Test mappable manifest leaf with cache op right over the leaf
  1180. cache
  1181. |
  1182. Manifest
  1183. """
  1184. logger.info("Test cache map manifest1")
  1185. if "SESSION_ID" in os.environ:
  1186. session_id = int(os.environ['SESSION_ID'])
  1187. else:
  1188. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1189. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1190. # This dataset has 4 records
  1191. ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
  1192. num_epoch = 4
  1193. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1194. epoch_count = 0
  1195. for _ in range(num_epoch):
  1196. assert sum([1 for _ in iter1]) == 4
  1197. epoch_count += 1
  1198. assert epoch_count == num_epoch
  1199. logger.info("test_cache_map_manifest1 Ended.\n")
  1200. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1201. def test_cache_map_manifest2():
  1202. """
  1203. Test mappable manifest leaf with the cache op later in the tree above the map(resize)
  1204. cache
  1205. |
  1206. Map(resize)
  1207. |
  1208. Manifest
  1209. """
  1210. logger.info("Test cache map manifest2")
  1211. if "SESSION_ID" in os.environ:
  1212. session_id = int(os.environ['SESSION_ID'])
  1213. else:
  1214. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1215. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1216. # This dataset has 4 records
  1217. ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
  1218. resize_op = c_vision.Resize((224, 224))
  1219. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1220. num_epoch = 4
  1221. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1222. epoch_count = 0
  1223. for _ in range(num_epoch):
  1224. assert sum([1 for _ in iter1]) == 4
  1225. epoch_count += 1
  1226. assert epoch_count == num_epoch
  1227. logger.info("test_cache_map_manifest2 Ended.\n")
  1228. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1229. def test_cache_map_cifar1():
  1230. """
  1231. Test mappable cifar10 leaf with cache op right over the leaf
  1232. cache
  1233. |
  1234. Cifar10
  1235. """
  1236. logger.info("Test cache map cifar1")
  1237. if "SESSION_ID" in os.environ:
  1238. session_id = int(os.environ['SESSION_ID'])
  1239. else:
  1240. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1241. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1242. ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
  1243. num_epoch = 4
  1244. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1245. epoch_count = 0
  1246. for _ in range(num_epoch):
  1247. assert sum([1 for _ in iter1]) == 10
  1248. epoch_count += 1
  1249. assert epoch_count == num_epoch
  1250. logger.info("test_cache_map_cifar1 Ended.\n")
  1251. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1252. def test_cache_map_cifar2():
  1253. """
  1254. Test mappable cifar100 leaf with the cache op later in the tree above the map(resize)
  1255. cache
  1256. |
  1257. Map(resize)
  1258. |
  1259. Cifar100
  1260. """
  1261. logger.info("Test cache map cifar2")
  1262. if "SESSION_ID" in os.environ:
  1263. session_id = int(os.environ['SESSION_ID'])
  1264. else:
  1265. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1266. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1267. ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
  1268. resize_op = c_vision.Resize((224, 224))
  1269. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1270. num_epoch = 4
  1271. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1272. epoch_count = 0
  1273. for _ in range(num_epoch):
  1274. assert sum([1 for _ in iter1]) == 10
  1275. epoch_count += 1
  1276. assert epoch_count == num_epoch
  1277. logger.info("test_cache_map_cifar2 Ended.\n")
  1278. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1279. def test_cache_map_voc1():
  1280. """
  1281. Test mappable voc leaf with cache op right over the leaf
  1282. cache
  1283. |
  1284. VOC
  1285. """
  1286. logger.info("Test cache map voc1")
  1287. if "SESSION_ID" in os.environ:
  1288. session_id = int(os.environ['SESSION_ID'])
  1289. else:
  1290. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1291. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1292. # This dataset has 9 records
  1293. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)
  1294. num_epoch = 4
  1295. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1296. epoch_count = 0
  1297. for _ in range(num_epoch):
  1298. assert sum([1 for _ in iter1]) == 9
  1299. epoch_count += 1
  1300. assert epoch_count == num_epoch
  1301. logger.info("test_cache_map_voc1 Ended.\n")
  1302. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1303. def test_cache_map_voc2():
  1304. """
  1305. Test mappable voc leaf with the cache op later in the tree above the map(resize)
  1306. cache
  1307. |
  1308. Map(resize)
  1309. |
  1310. VOC
  1311. """
  1312. logger.info("Test cache map voc2")
  1313. if "SESSION_ID" in os.environ:
  1314. session_id = int(os.environ['SESSION_ID'])
  1315. else:
  1316. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1317. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1318. # This dataset has 9 records
  1319. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
  1320. resize_op = c_vision.Resize((224, 224))
  1321. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1322. num_epoch = 4
  1323. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1324. epoch_count = 0
  1325. for _ in range(num_epoch):
  1326. assert sum([1 for _ in iter1]) == 9
  1327. epoch_count += 1
  1328. assert epoch_count == num_epoch
  1329. logger.info("test_cache_map_voc2 Ended.\n")
  1330. if __name__ == '__main__':
  1331. test_cache_map_basic1()
  1332. test_cache_map_basic2()
  1333. test_cache_map_basic3()
  1334. test_cache_map_basic4()
  1335. test_cache_map_failure1()
  1336. test_cache_map_failure2()
  1337. test_cache_map_failure3()
  1338. test_cache_map_failure4()