You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_cache_map.py 70 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """
  16. Testing cache operator with mappable datasets
  17. """
  18. import os
  19. import pytest
  20. import numpy as np
  21. import mindspore.dataset as ds
  22. import mindspore.dataset.vision.c_transforms as c_vision
  23. import mindspore.dataset.vision.py_transforms as py_vision
  24. from mindspore import log as logger
  25. from util import save_and_check_md5
  26. DATA_DIR = "../data/dataset/testImageNetData/train/"
  27. COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
  28. COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
  29. NO_IMAGE_DIR = "../data/dataset/testRandomData/"
  30. MNIST_DATA_DIR = "../data/dataset/testMnistData/"
  31. CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
  32. VOC_DATA_DIR = "../data/dataset/testVOC2012/"
  33. MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
  34. CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
  35. CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
  36. MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"
  37. GENERATE_GOLDEN = False
  38. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  39. def test_cache_map_basic1():
  40. """
  41. Test mappable leaf with cache op right over the leaf
  42. Repeat
  43. |
  44. Map(decode)
  45. |
  46. Cache
  47. |
  48. ImageFolder
  49. """
  50. logger.info("Test cache map basic 1")
  51. if "SESSION_ID" in os.environ:
  52. session_id = int(os.environ['SESSION_ID'])
  53. else:
  54. session_id = 1
  55. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  56. # This DATA_DIR only has 2 images in it
  57. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  58. decode_op = c_vision.Decode()
  59. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  60. ds1 = ds1.repeat(4)
  61. filename = "cache_map_01_result.npz"
  62. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  63. logger.info("test_cache_map_basic1 Ended.\n")
  64. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  65. def test_cache_map_basic2():
  66. """
  67. Test mappable leaf with the cache op later in the tree above the map(decode)
  68. Repeat
  69. |
  70. Cache
  71. |
  72. Map(decode)
  73. |
  74. ImageFolder
  75. """
  76. logger.info("Test cache map basic 2")
  77. if "SESSION_ID" in os.environ:
  78. session_id = int(os.environ['SESSION_ID'])
  79. else:
  80. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  81. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  82. # This DATA_DIR only has 2 images in it
  83. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  84. decode_op = c_vision.Decode()
  85. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  86. ds1 = ds1.repeat(4)
  87. filename = "cache_map_02_result.npz"
  88. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  89. logger.info("test_cache_map_basic2 Ended.\n")
  90. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  91. def test_cache_map_basic3():
  92. """
  93. Test different rows result in core dump
  94. """
  95. logger.info("Test cache basic 3")
  96. if "SESSION_ID" in os.environ:
  97. session_id = int(os.environ['SESSION_ID'])
  98. else:
  99. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  100. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  101. # This DATA_DIR only has 2 images in it
  102. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  103. decode_op = c_vision.Decode()
  104. ds1 = ds1.repeat(4)
  105. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  106. logger.info("ds1.dataset_size is ", ds1.get_dataset_size())
  107. shape = ds1.output_shapes()
  108. logger.info(shape)
  109. num_iter = 0
  110. for _ in ds1.create_dict_iterator(num_epochs=1):
  111. logger.info("get data from dataset")
  112. num_iter += 1
  113. logger.info("Number of data in ds1: {} ".format(num_iter))
  114. assert num_iter == 8
  115. logger.info('test_cache_basic3 Ended.\n')
  116. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  117. def test_cache_map_basic4():
  118. """
  119. Test Map containing random operation above cache
  120. repeat
  121. |
  122. Map(decode, randomCrop)
  123. |
  124. Cache
  125. |
  126. ImageFolder
  127. """
  128. logger.info("Test cache basic 4")
  129. if "SESSION_ID" in os.environ:
  130. session_id = int(os.environ['SESSION_ID'])
  131. else:
  132. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  133. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  134. # This DATA_DIR only has 2 images in it
  135. data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  136. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  137. decode_op = c_vision.Decode()
  138. data = data.map(input_columns=["image"], operations=decode_op)
  139. data = data.map(input_columns=["image"], operations=random_crop_op)
  140. data = data.repeat(4)
  141. num_iter = 0
  142. for _ in data.create_dict_iterator():
  143. num_iter += 1
  144. logger.info("Number of data in ds1: {} ".format(num_iter))
  145. assert num_iter == 8
  146. logger.info('test_cache_basic4 Ended.\n')
  147. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  148. def test_cache_map_basic5():
  149. """
  150. Test cache as root node
  151. cache
  152. |
  153. ImageFolder
  154. """
  155. logger.info("Test cache basic 5")
  156. if "SESSION_ID" in os.environ:
  157. session_id = int(os.environ['SESSION_ID'])
  158. else:
  159. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  160. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  161. # This DATA_DIR only has 2 images in it
  162. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  163. num_iter = 0
  164. for _ in ds1.create_dict_iterator(num_epochs=1):
  165. logger.info("get data from dataset")
  166. num_iter += 1
  167. logger.info("Number of data in ds1: {} ".format(num_iter))
  168. assert num_iter == 2
  169. logger.info('test_cache_basic5 Ended.\n')
  170. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  171. def test_cache_map_failure1():
  172. """
  173. Test nested cache (failure)
  174. Repeat
  175. |
  176. Cache
  177. |
  178. Map(decode)
  179. |
  180. Cache
  181. |
  182. Coco
  183. """
  184. logger.info("Test cache failure 1")
  185. if "SESSION_ID" in os.environ:
  186. session_id = int(os.environ['SESSION_ID'])
  187. else:
  188. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  189. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  190. # This DATA_DIR has 6 images in it
  191. ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
  192. cache=some_cache)
  193. decode_op = c_vision.Decode()
  194. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  195. ds1 = ds1.repeat(4)
  196. with pytest.raises(RuntimeError) as e:
  197. ds1.get_batch_size()
  198. assert "Nested cache operations" in str(e.value)
  199. with pytest.raises(RuntimeError) as e:
  200. num_iter = 0
  201. for _ in ds1.create_dict_iterator(num_epochs=1):
  202. num_iter += 1
  203. assert "Nested cache operations" in str(e.value)
  204. assert num_iter == 0
  205. logger.info('test_cache_failure1 Ended.\n')
  206. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  207. def test_cache_map_failure2():
  208. """
  209. Test zip under cache (failure)
  210. repeat
  211. |
  212. Cache
  213. |
  214. Map(decode)
  215. |
  216. Zip
  217. | |
  218. ImageFolder ImageFolder
  219. """
  220. logger.info("Test cache failure 2")
  221. if "SESSION_ID" in os.environ:
  222. session_id = int(os.environ['SESSION_ID'])
  223. else:
  224. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  225. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  226. # This DATA_DIR only has 2 images in it
  227. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  228. ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  229. dsz = ds.zip((ds1, ds2))
  230. decode_op = c_vision.Decode()
  231. dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  232. dsz = dsz.repeat(4)
  233. with pytest.raises(RuntimeError) as e:
  234. num_iter = 0
  235. for _ in dsz.create_dict_iterator():
  236. num_iter += 1
  237. assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)
  238. assert num_iter == 0
  239. logger.info('test_cache_failure2 Ended.\n')
  240. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  241. def test_cache_map_failure3():
  242. """
  243. Test batch under cache (failure)
  244. repeat
  245. |
  246. Cache
  247. |
  248. Map(resize)
  249. |
  250. Batch
  251. |
  252. Mnist
  253. """
  254. logger.info("Test cache failure 3")
  255. if "SESSION_ID" in os.environ:
  256. session_id = int(os.environ['SESSION_ID'])
  257. else:
  258. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  259. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  260. ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
  261. ds1 = ds1.batch(2)
  262. resize_op = c_vision.Resize((224, 224))
  263. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  264. ds1 = ds1.repeat(4)
  265. with pytest.raises(RuntimeError) as e:
  266. num_iter = 0
  267. for _ in ds1.create_dict_iterator():
  268. num_iter += 1
  269. assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)
  270. assert num_iter == 0
  271. logger.info('test_cache_failure3 Ended.\n')
  272. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  273. def test_cache_map_failure4():
  274. """
  275. Test filter under cache (failure)
  276. repeat
  277. |
  278. Cache
  279. |
  280. Map(decode)
  281. |
  282. Filter
  283. |
  284. CelebA
  285. """
  286. logger.info("Test cache failure 4")
  287. if "SESSION_ID" in os.environ:
  288. session_id = int(os.environ['SESSION_ID'])
  289. else:
  290. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  291. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  292. # This dataset has 4 records
  293. ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
  294. ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])
  295. decode_op = c_vision.Decode()
  296. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  297. ds1 = ds1.repeat(4)
  298. with pytest.raises(RuntimeError) as e:
  299. num_iter = 0
  300. for _ in ds1.create_dict_iterator():
  301. num_iter += 1
  302. assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)
  303. assert num_iter == 0
  304. logger.info('test_cache_failure4 Ended.\n')
  305. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  306. def test_cache_map_failure5():
  307. """
  308. Test Map containing random operation under cache (failure)
  309. repeat
  310. |
  311. Cache
  312. |
  313. Map(decode, randomCrop)
  314. |
  315. Manifest
  316. """
  317. logger.info("Test cache failure 5")
  318. if "SESSION_ID" in os.environ:
  319. session_id = int(os.environ['SESSION_ID'])
  320. else:
  321. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  322. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  323. # This dataset has 4 records
  324. data = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
  325. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  326. decode_op = c_vision.Decode()
  327. data = data.map(input_columns=["image"], operations=decode_op)
  328. data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
  329. data = data.repeat(4)
  330. with pytest.raises(RuntimeError) as e:
  331. num_iter = 0
  332. for _ in data.create_dict_iterator():
  333. num_iter += 1
  334. assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
  335. assert num_iter == 0
  336. logger.info('test_cache_failure5 Ended.\n')
  337. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  338. def test_cache_map_failure7():
  339. """
  340. Test no-cache-supporting Generator leaf with Map under cache (failure)
  341. repeat
  342. |
  343. Cache
  344. |
  345. Map(lambda x: x)
  346. |
  347. Generator
  348. """
  349. def generator_1d():
  350. for i in range(64):
  351. yield (np.array(i),)
  352. logger.info("Test cache failure 7")
  353. if "SESSION_ID" in os.environ:
  354. session_id = int(os.environ['SESSION_ID'])
  355. else:
  356. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  357. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  358. data = ds.GeneratorDataset(generator_1d, ["data"])
  359. data = data.map(py_vision.not_random(lambda x: x), ["data"], cache=some_cache)
  360. data = data.repeat(4)
  361. with pytest.raises(RuntimeError) as e:
  362. num_iter = 0
  363. for _ in data.create_dict_iterator():
  364. num_iter += 1
  365. assert "There is currently no support for GeneratorOp under cache" in str(e.value)
  366. assert num_iter == 0
  367. logger.info('test_cache_failure7 Ended.\n')
  368. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  369. def test_cache_map_failure8():
  370. """
  371. Test a repeat under mappable cache (failure)
  372. Cache
  373. |
  374. Map(decode)
  375. |
  376. Repeat
  377. |
  378. Cifar10
  379. """
  380. logger.info("Test cache failure 8")
  381. if "SESSION_ID" in os.environ:
  382. session_id = int(os.environ['SESSION_ID'])
  383. else:
  384. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  385. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  386. ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10)
  387. decode_op = c_vision.Decode()
  388. ds1 = ds1.repeat(4)
  389. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  390. with pytest.raises(RuntimeError) as e:
  391. num_iter = 0
  392. for _ in ds1.create_dict_iterator(num_epochs=1):
  393. num_iter += 1
  394. assert "A cache over a RepeatNode of a mappable dataset is not supported" in str(e.value)
  395. assert num_iter == 0
  396. logger.info('test_cache_failure8 Ended.\n')
  397. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  398. def test_cache_map_failure9():
  399. """
  400. Test take under cache (failure)
  401. repeat
  402. |
  403. Cache
  404. |
  405. Map(decode)
  406. |
  407. Take
  408. |
  409. Cifar100
  410. """
  411. logger.info("Test cache failure 9")
  412. if "SESSION_ID" in os.environ:
  413. session_id = int(os.environ['SESSION_ID'])
  414. else:
  415. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  416. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  417. ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
  418. ds1 = ds1.take(2)
  419. decode_op = c_vision.Decode()
  420. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  421. ds1 = ds1.repeat(4)
  422. with pytest.raises(RuntimeError) as e:
  423. num_iter = 0
  424. for _ in ds1.create_dict_iterator():
  425. num_iter += 1
  426. assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
  427. assert num_iter == 0
  428. logger.info('test_cache_failure9 Ended.\n')
  429. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  430. def test_cache_map_failure10():
  431. """
  432. Test skip under cache (failure)
  433. repeat
  434. |
  435. Cache
  436. |
  437. Map(decode)
  438. |
  439. Skip
  440. |
  441. VOC
  442. """
  443. logger.info("Test cache failure 10")
  444. if "SESSION_ID" in os.environ:
  445. session_id = int(os.environ['SESSION_ID'])
  446. else:
  447. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  448. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  449. # This dataset has 9 records
  450. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
  451. ds1 = ds1.skip(1)
  452. decode_op = c_vision.Decode()
  453. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  454. ds1 = ds1.repeat(4)
  455. with pytest.raises(RuntimeError) as e:
  456. num_iter = 0
  457. for _ in ds1.create_dict_iterator():
  458. num_iter += 1
  459. assert "SkipNode is not supported as a descendant operator under a cache" in str(e.value)
  460. assert num_iter == 0
  461. logger.info('test_cache_failure10 Ended.\n')
  462. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  463. def test_cache_map_failure11():
  464. """
  465. Test set spilling=true when cache server is started without spilling support (failure)
  466. Cache(spilling=true)
  467. |
  468. ImageFolder
  469. """
  470. logger.info("Test cache failure 11")
  471. if "SESSION_ID" in os.environ:
  472. session_id = int(os.environ['SESSION_ID'])
  473. else:
  474. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  475. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  476. # This DATA_DIR only has 2 images in it
  477. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  478. with pytest.raises(RuntimeError) as e:
  479. num_iter = 0
  480. for _ in ds1.create_dict_iterator():
  481. num_iter += 1
  482. assert "Unexpected error. Server is not set up with spill support" in str(e.value)
  483. assert num_iter == 0
  484. logger.info('test_cache_failure11 Ended.\n')
  485. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  486. def test_cache_map_split1():
  487. """
  488. Test split (after a non-source node) under cache (failure).
  489. Split after a non-source node is implemented with TakeOp/SkipOp, hence the failure.
  490. repeat
  491. |
  492. Cache
  493. |
  494. Map(resize)
  495. |
  496. Split
  497. |
  498. Map(decode)
  499. |
  500. ImageFolder
  501. """
  502. logger.info("Test cache split 1")
  503. if "SESSION_ID" in os.environ:
  504. session_id = int(os.environ['SESSION_ID'])
  505. else:
  506. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  507. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  508. # This DATA_DIR only has 2 images in it
  509. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  510. decode_op = c_vision.Decode()
  511. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  512. ds1, ds2 = ds1.split([0.5, 0.5])
  513. resize_op = c_vision.Resize((224, 224))
  514. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  515. ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  516. ds1 = ds1.repeat(4)
  517. ds2 = ds2.repeat(4)
  518. with pytest.raises(RuntimeError) as e:
  519. num_iter = 0
  520. for _ in ds1.create_dict_iterator():
  521. num_iter += 1
  522. assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
  523. with pytest.raises(RuntimeError) as e:
  524. num_iter = 0
  525. for _ in ds2.create_dict_iterator():
  526. num_iter += 1
  527. assert "TakeNode (possibly from Split) is not supported as a descendant operator under a cache" in str(e.value)
  528. logger.info('test_cache_split1 Ended.\n')
  529. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  530. def test_cache_map_split2():
  531. """
  532. Test split (after a source node) under cache (ok).
  533. Split after a source node is implemented with subset sampler, hence ok.
  534. repeat
  535. |
  536. Cache
  537. |
  538. Map(resize)
  539. |
  540. Split
  541. |
  542. VOCDataset
  543. """
  544. logger.info("Test cache split 2")
  545. if "SESSION_ID" in os.environ:
  546. session_id = int(os.environ['SESSION_ID'])
  547. else:
  548. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  549. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  550. # This dataset has 9 records
  551. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
  552. ds1, ds2 = ds1.split([0.3, 0.7])
  553. resize_op = c_vision.Resize((224, 224))
  554. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  555. ds2 = ds2.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  556. ds1 = ds1.repeat(4)
  557. ds2 = ds2.repeat(4)
  558. num_iter = 0
  559. for _ in ds1.create_dict_iterator():
  560. num_iter += 1
  561. assert num_iter == 12
  562. num_iter = 0
  563. for _ in ds2.create_dict_iterator():
  564. num_iter += 1
  565. assert num_iter == 24
  566. logger.info('test_cache_split2 Ended.\n')
  567. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  568. def test_cache_map_parameter_check():
  569. """
  570. Test illegal parameters for DatasetCache
  571. """
  572. logger.info("Test cache map parameter check")
  573. with pytest.raises(ValueError) as info:
  574. ds.DatasetCache(session_id=-1, size=0)
  575. assert "Input is not within the required interval" in str(info.value)
  576. with pytest.raises(TypeError) as info:
  577. ds.DatasetCache(session_id="1", size=0)
  578. assert "Argument session_id with value 1 is not of type" in str(info.value)
  579. with pytest.raises(TypeError) as info:
  580. ds.DatasetCache(session_id=None, size=0)
  581. assert "Argument session_id with value None is not of type" in str(info.value)
  582. with pytest.raises(ValueError) as info:
  583. ds.DatasetCache(session_id=1, size=-1)
  584. assert "Input size must be greater than 0" in str(info.value)
  585. with pytest.raises(TypeError) as info:
  586. ds.DatasetCache(session_id=1, size="1")
  587. assert "Argument size with value 1 is not of type" in str(info.value)
  588. with pytest.raises(TypeError) as info:
  589. ds.DatasetCache(session_id=1, size=None)
  590. assert "Argument size with value None is not of type" in str(info.value)
  591. with pytest.raises(TypeError) as info:
  592. ds.DatasetCache(session_id=1, size=0, spilling="illegal")
  593. assert "Argument spilling with value illegal is not of type" in str(info.value)
  594. with pytest.raises(TypeError) as err:
  595. ds.DatasetCache(session_id=1, size=0, hostname=50052)
  596. assert "Argument hostname with value 50052 is not of type" in str(err.value)
  597. with pytest.raises(RuntimeError) as err:
  598. ds.DatasetCache(session_id=1, size=0, hostname="illegal")
  599. assert "now cache client has to be on the same host with cache server" in str(err.value)
  600. with pytest.raises(RuntimeError) as err:
  601. ds.DatasetCache(session_id=1, size=0, hostname="127.0.0.2")
  602. assert "now cache client has to be on the same host with cache server" in str(err.value)
  603. with pytest.raises(TypeError) as info:
  604. ds.DatasetCache(session_id=1, size=0, port="illegal")
  605. assert "Argument port with value illegal is not of type" in str(info.value)
  606. with pytest.raises(TypeError) as info:
  607. ds.DatasetCache(session_id=1, size=0, port="50052")
  608. assert "Argument port with value 50052 is not of type" in str(info.value)
  609. with pytest.raises(ValueError) as err:
  610. ds.DatasetCache(session_id=1, size=0, port=0)
  611. assert "Input port is not within the required interval of [1025, 65535]" in str(err.value)
  612. with pytest.raises(ValueError) as err:
  613. ds.DatasetCache(session_id=1, size=0, port=65536)
  614. assert "Input port is not within the required interval of [1025, 65535]" in str(err.value)
  615. with pytest.raises(TypeError) as err:
  616. ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)
  617. assert "Argument cache with value True is not of type" in str(err.value)
  618. logger.info("test_cache_map_parameter_check Ended.\n")
  619. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  620. def test_cache_map_running_twice1():
  621. """
  622. Executing the same pipeline for twice (from python), with cache injected after map
  623. Repeat
  624. |
  625. Cache
  626. |
  627. Map(decode)
  628. |
  629. ImageFolder
  630. """
  631. logger.info("Test cache map running twice 1")
  632. if "SESSION_ID" in os.environ:
  633. session_id = int(os.environ['SESSION_ID'])
  634. else:
  635. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  636. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  637. # This DATA_DIR only has 2 images in it
  638. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  639. decode_op = c_vision.Decode()
  640. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  641. ds1 = ds1.repeat(4)
  642. num_iter = 0
  643. for _ in ds1.create_dict_iterator():
  644. num_iter += 1
  645. logger.info("Number of data in ds1: {} ".format(num_iter))
  646. assert num_iter == 8
  647. num_iter = 0
  648. for _ in ds1.create_dict_iterator():
  649. num_iter += 1
  650. logger.info("Number of data in ds1: {} ".format(num_iter))
  651. assert num_iter == 8
  652. logger.info("test_cache_map_running_twice1 Ended.\n")
  653. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  654. def test_cache_map_running_twice2():
  655. """
  656. Executing the same pipeline for twice (from shell), with cache injected after leaf
  657. Repeat
  658. |
  659. Map(decode)
  660. |
  661. Cache
  662. |
  663. ImageFolder
  664. """
  665. logger.info("Test cache map running twice 2")
  666. if "SESSION_ID" in os.environ:
  667. session_id = int(os.environ['SESSION_ID'])
  668. else:
  669. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  670. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  671. # This DATA_DIR only has 2 images in it
  672. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  673. decode_op = c_vision.Decode()
  674. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  675. ds1 = ds1.repeat(4)
  676. num_iter = 0
  677. for _ in ds1.create_dict_iterator():
  678. num_iter += 1
  679. logger.info("Number of data in ds1: {} ".format(num_iter))
  680. assert num_iter == 8
  681. logger.info("test_cache_map_running_twice2 Ended.\n")
  682. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  683. def test_cache_map_extra_small_size1():
  684. """
  685. Test running pipeline with cache of extra small size and spilling true
  686. Repeat
  687. |
  688. Map(decode)
  689. |
  690. Cache
  691. |
  692. ImageFolder
  693. """
  694. logger.info("Test cache map extra small size 1")
  695. if "SESSION_ID" in os.environ:
  696. session_id = int(os.environ['SESSION_ID'])
  697. else:
  698. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  699. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)
  700. # This DATA_DIR only has 2 images in it
  701. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  702. decode_op = c_vision.Decode()
  703. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  704. ds1 = ds1.repeat(4)
  705. num_iter = 0
  706. for _ in ds1.create_dict_iterator():
  707. num_iter += 1
  708. logger.info("Number of data in ds1: {} ".format(num_iter))
  709. assert num_iter == 8
  710. logger.info("test_cache_map_extra_small_size1 Ended.\n")
  711. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  712. def test_cache_map_extra_small_size2():
  713. """
  714. Test running pipeline with cache of extra small size and spilling false
  715. Repeat
  716. |
  717. Cache
  718. |
  719. Map(decode)
  720. |
  721. ImageFolder
  722. """
  723. logger.info("Test cache map extra small size 2")
  724. if "SESSION_ID" in os.environ:
  725. session_id = int(os.environ['SESSION_ID'])
  726. else:
  727. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  728. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
  729. # This DATA_DIR only has 2 images in it
  730. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  731. decode_op = c_vision.Decode()
  732. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  733. ds1 = ds1.repeat(4)
  734. num_iter = 0
  735. for _ in ds1.create_dict_iterator():
  736. num_iter += 1
  737. logger.info("Number of data in ds1: {} ".format(num_iter))
  738. assert num_iter == 8
  739. logger.info("test_cache_map_extra_small_size2 Ended.\n")
  740. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  741. def test_cache_map_no_image():
  742. """
  743. Test cache with no dataset existing in the path
  744. Repeat
  745. |
  746. Map(decode)
  747. |
  748. Cache
  749. |
  750. ImageFolder
  751. """
  752. logger.info("Test cache map no image")
  753. if "SESSION_ID" in os.environ:
  754. session_id = int(os.environ['SESSION_ID'])
  755. else:
  756. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  757. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
  758. # This DATA_DIR only has 2 images in it
  759. ds1 = ds.ImageFolderDataset(dataset_dir=NO_IMAGE_DIR, cache=some_cache)
  760. decode_op = c_vision.Decode()
  761. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  762. ds1 = ds1.repeat(4)
  763. with pytest.raises(RuntimeError):
  764. num_iter = 0
  765. for _ in ds1.create_dict_iterator():
  766. num_iter += 1
  767. assert num_iter == 0
  768. logger.info("test_cache_map_no_image Ended.\n")
  769. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  770. def test_cache_map_parallel_pipeline1(shard):
  771. """
  772. Test running two parallel pipelines (sharing cache) with cache injected after leaf op
  773. Repeat
  774. |
  775. Map(decode)
  776. |
  777. Cache
  778. |
  779. ImageFolder
  780. """
  781. logger.info("Test cache map parallel pipeline 1")
  782. if "SESSION_ID" in os.environ:
  783. session_id = int(os.environ['SESSION_ID'])
  784. else:
  785. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  786. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  787. # This DATA_DIR only has 2 images in it
  788. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache)
  789. decode_op = c_vision.Decode()
  790. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  791. ds1 = ds1.repeat(4)
  792. num_iter = 0
  793. for _ in ds1.create_dict_iterator():
  794. num_iter += 1
  795. logger.info("Number of data in ds1: {} ".format(num_iter))
  796. assert num_iter == 4
  797. logger.info("test_cache_map_parallel_pipeline1 Ended.\n")
  798. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  799. def test_cache_map_parallel_pipeline2(shard):
  800. """
  801. Test running two parallel pipelines (sharing cache) with cache injected after map op
  802. Repeat
  803. |
  804. Cache
  805. |
  806. Map(decode)
  807. |
  808. ImageFolder
  809. """
  810. logger.info("Test cache map parallel pipeline 2")
  811. if "SESSION_ID" in os.environ:
  812. session_id = int(os.environ['SESSION_ID'])
  813. else:
  814. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  815. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  816. # This DATA_DIR only has 2 images in it
  817. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
  818. decode_op = c_vision.Decode()
  819. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  820. ds1 = ds1.repeat(4)
  821. num_iter = 0
  822. for _ in ds1.create_dict_iterator():
  823. num_iter += 1
  824. logger.info("Number of data in ds1: {} ".format(num_iter))
  825. assert num_iter == 4
  826. logger.info("test_cache_map_parallel_pipeline2 Ended.\n")
  827. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  828. def test_cache_map_parallel_workers():
  829. """
  830. Test cache with num_parallel_workers > 1 set for map op and leaf op
  831. Repeat
  832. |
  833. cache
  834. |
  835. Map(decode)
  836. |
  837. ImageFolder
  838. """
  839. logger.info("Test cache map parallel workers")
  840. if "SESSION_ID" in os.environ:
  841. session_id = int(os.environ['SESSION_ID'])
  842. else:
  843. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  844. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  845. # This DATA_DIR only has 2 images in it
  846. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
  847. decode_op = c_vision.Decode()
  848. ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
  849. ds1 = ds1.repeat(4)
  850. num_iter = 0
  851. for _ in ds1.create_dict_iterator():
  852. num_iter += 1
  853. logger.info("Number of data in ds1: {} ".format(num_iter))
  854. assert num_iter == 8
  855. logger.info("test_cache_map_parallel_workers Ended.\n")
  856. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  857. def test_cache_map_server_workers_1():
  858. """
  859. start cache server with --workers 1 and then test cache function
  860. Repeat
  861. |
  862. cache
  863. |
  864. Map(decode)
  865. |
  866. ImageFolder
  867. """
  868. logger.info("Test cache map server workers 1")
  869. if "SESSION_ID" in os.environ:
  870. session_id = int(os.environ['SESSION_ID'])
  871. else:
  872. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  873. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  874. # This DATA_DIR only has 2 images in it
  875. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  876. decode_op = c_vision.Decode()
  877. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  878. ds1 = ds1.repeat(4)
  879. num_iter = 0
  880. for _ in ds1.create_dict_iterator():
  881. num_iter += 1
  882. logger.info("Number of data in ds1: {} ".format(num_iter))
  883. assert num_iter == 8
  884. logger.info("test_cache_map_server_workers_1 Ended.\n")
  885. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  886. def test_cache_map_server_workers_100():
  887. """
  888. start cache server with --workers 100 and then test cache function
  889. Repeat
  890. |
  891. Map(decode)
  892. |
  893. cache
  894. |
  895. ImageFolder
  896. """
  897. logger.info("Test cache map server workers 100")
  898. if "SESSION_ID" in os.environ:
  899. session_id = int(os.environ['SESSION_ID'])
  900. else:
  901. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  902. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  903. # This DATA_DIR only has 2 images in it
  904. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  905. decode_op = c_vision.Decode()
  906. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  907. ds1 = ds1.repeat(4)
  908. num_iter = 0
  909. for _ in ds1.create_dict_iterator():
  910. num_iter += 1
  911. logger.info("Number of data in ds1: {} ".format(num_iter))
  912. assert num_iter == 8
  913. logger.info("test_cache_map_server_workers_100 Ended.\n")
  914. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  915. def test_cache_map_num_connections_1():
  916. """
  917. Test setting num_connections=1 in DatasetCache
  918. Repeat
  919. |
  920. cache
  921. |
  922. Map(decode)
  923. |
  924. ImageFolder
  925. """
  926. logger.info("Test cache map num_connections 1")
  927. if "SESSION_ID" in os.environ:
  928. session_id = int(os.environ['SESSION_ID'])
  929. else:
  930. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  931. some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)
  932. # This DATA_DIR only has 2 images in it
  933. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  934. decode_op = c_vision.Decode()
  935. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  936. ds1 = ds1.repeat(4)
  937. num_iter = 0
  938. for _ in ds1.create_dict_iterator():
  939. num_iter += 1
  940. logger.info("Number of data in ds1: {} ".format(num_iter))
  941. assert num_iter == 8
  942. logger.info("test_cache_map_num_connections_1 Ended.\n")
  943. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  944. def test_cache_map_num_connections_100():
  945. """
  946. Test setting num_connections=100 in DatasetCache
  947. Repeat
  948. |
  949. Map(decode)
  950. |
  951. cache
  952. |
  953. ImageFolder
  954. """
  955. logger.info("Test cache map num_connections 100")
  956. if "SESSION_ID" in os.environ:
  957. session_id = int(os.environ['SESSION_ID'])
  958. else:
  959. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  960. some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)
  961. # This DATA_DIR only has 2 images in it
  962. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  963. decode_op = c_vision.Decode()
  964. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  965. ds1 = ds1.repeat(4)
  966. num_iter = 0
  967. for _ in ds1.create_dict_iterator():
  968. num_iter += 1
  969. logger.info("Number of data in ds1: {} ".format(num_iter))
  970. assert num_iter == 8
  971. logger.info("test_cache_map_num_connections_100 Ended.\n")
  972. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  973. def test_cache_map_prefetch_size_1():
  974. """
  975. Test setting prefetch_size=1 in DatasetCache
  976. Repeat
  977. |
  978. cache
  979. |
  980. Map(decode)
  981. |
  982. ImageFolder
  983. """
  984. logger.info("Test cache map prefetch_size 1")
  985. if "SESSION_ID" in os.environ:
  986. session_id = int(os.environ['SESSION_ID'])
  987. else:
  988. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  989. some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)
  990. # This DATA_DIR only has 2 images in it
  991. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  992. decode_op = c_vision.Decode()
  993. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  994. ds1 = ds1.repeat(4)
  995. num_iter = 0
  996. for _ in ds1.create_dict_iterator():
  997. num_iter += 1
  998. logger.info("Number of data in ds1: {} ".format(num_iter))
  999. assert num_iter == 8
  1000. logger.info("test_cache_map_prefetch_size_1 Ended.\n")
  1001. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1002. def test_cache_map_prefetch_size_100():
  1003. """
  1004. Test setting prefetch_size=100 in DatasetCache
  1005. Repeat
  1006. |
  1007. Map(decode)
  1008. |
  1009. cache
  1010. |
  1011. ImageFolder
  1012. """
  1013. logger.info("Test cache map prefetch_size 100")
  1014. if "SESSION_ID" in os.environ:
  1015. session_id = int(os.environ['SESSION_ID'])
  1016. else:
  1017. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1018. some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)
  1019. # This DATA_DIR only has 2 images in it
  1020. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  1021. decode_op = c_vision.Decode()
  1022. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  1023. ds1 = ds1.repeat(4)
  1024. num_iter = 0
  1025. for _ in ds1.create_dict_iterator():
  1026. num_iter += 1
  1027. logger.info("Number of data in ds1: {} ".format(num_iter))
  1028. assert num_iter == 8
  1029. logger.info("test_cache_map_prefetch_size_100 Ended.\n")
  1030. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1031. def test_cache_map_to_device():
  1032. """
  1033. Test cache with to_device
  1034. DeviceQueue
  1035. |
  1036. EpochCtrl
  1037. |
  1038. Repeat
  1039. |
  1040. Map(decode)
  1041. |
  1042. cache
  1043. |
  1044. ImageFolder
  1045. """
  1046. logger.info("Test cache map to_device")
  1047. if "SESSION_ID" in os.environ:
  1048. session_id = int(os.environ['SESSION_ID'])
  1049. else:
  1050. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1051. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1052. # This DATA_DIR only has 2 images in it
  1053. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  1054. decode_op = c_vision.Decode()
  1055. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  1056. ds1 = ds1.repeat(4)
  1057. ds1 = ds1.to_device()
  1058. ds1.send()
  1059. logger.info("test_cache_map_to_device Ended.\n")
  1060. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1061. def test_cache_map_epoch_ctrl1():
  1062. """
  1063. Test using two-loops method to run several epochs
  1064. Map(decode)
  1065. |
  1066. cache
  1067. |
  1068. ImageFolder
  1069. """
  1070. logger.info("Test cache map epoch ctrl1")
  1071. if "SESSION_ID" in os.environ:
  1072. session_id = int(os.environ['SESSION_ID'])
  1073. else:
  1074. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1075. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1076. # This DATA_DIR only has 2 images in it
  1077. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  1078. decode_op = c_vision.Decode()
  1079. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  1080. num_epoch = 5
  1081. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1082. epoch_count = 0
  1083. for _ in range(num_epoch):
  1084. row_count = 0
  1085. for _ in iter1:
  1086. row_count += 1
  1087. logger.info("Number of data in ds1: {} ".format(row_count))
  1088. assert row_count == 2
  1089. epoch_count += 1
  1090. assert epoch_count == num_epoch
  1091. logger.info("test_cache_map_epoch_ctrl1 Ended.\n")
  1092. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1093. def test_cache_map_epoch_ctrl2():
  1094. """
  1095. Test using two-loops method with infinite epochs
  1096. cache
  1097. |
  1098. Map(decode)
  1099. |
  1100. ImageFolder
  1101. """
  1102. logger.info("Test cache map epoch ctrl2")
  1103. if "SESSION_ID" in os.environ:
  1104. session_id = int(os.environ['SESSION_ID'])
  1105. else:
  1106. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1107. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1108. # This DATA_DIR only has 2 images in it
  1109. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  1110. decode_op = c_vision.Decode()
  1111. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  1112. num_epoch = 5
  1113. # iter1 will always assume there is a next epoch and never shutdown
  1114. iter1 = ds1.create_dict_iterator()
  1115. epoch_count = 0
  1116. for _ in range(num_epoch):
  1117. row_count = 0
  1118. for _ in iter1:
  1119. row_count += 1
  1120. logger.info("Number of data in ds1: {} ".format(row_count))
  1121. assert row_count == 2
  1122. epoch_count += 1
  1123. assert epoch_count == num_epoch
  1124. # manually stop the iterator
  1125. iter1.stop()
  1126. logger.info("test_cache_map_epoch_ctrl2 Ended.\n")
  1127. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1128. def test_cache_map_epoch_ctrl3():
  1129. """
  1130. Test using two-loops method with infinite epochs over repeat
  1131. repeat
  1132. |
  1133. Map(decode)
  1134. |
  1135. cache
  1136. |
  1137. ImageFolder
  1138. """
  1139. logger.info("Test cache map epoch ctrl3")
  1140. if "SESSION_ID" in os.environ:
  1141. session_id = int(os.environ['SESSION_ID'])
  1142. else:
  1143. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1144. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1145. # This DATA_DIR only has 2 images in it
  1146. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  1147. decode_op = c_vision.Decode()
  1148. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  1149. ds1 = ds1.repeat(2)
  1150. num_epoch = 5
  1151. # iter1 will always assume there is a next epoch and never shutdown
  1152. iter1 = ds1.create_dict_iterator()
  1153. epoch_count = 0
  1154. for _ in range(num_epoch):
  1155. row_count = 0
  1156. for _ in iter1:
  1157. row_count += 1
  1158. logger.info("Number of data in ds1: {} ".format(row_count))
  1159. assert row_count == 4
  1160. epoch_count += 1
  1161. assert epoch_count == num_epoch
  1162. # reply on garbage collector to destroy iter1
  1163. logger.info("test_cache_map_epoch_ctrl3 Ended.\n")
  1164. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1165. def test_cache_map_coco1():
  1166. """
  1167. Test mappable coco leaf with cache op right over the leaf
  1168. cache
  1169. |
  1170. Coco
  1171. """
  1172. logger.info("Test cache map coco1")
  1173. if "SESSION_ID" in os.environ:
  1174. session_id = int(os.environ['SESSION_ID'])
  1175. else:
  1176. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1177. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1178. # This dataset has 6 records
  1179. ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
  1180. cache=some_cache)
  1181. num_epoch = 4
  1182. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1183. epoch_count = 0
  1184. for _ in range(num_epoch):
  1185. assert sum([1 for _ in iter1]) == 6
  1186. epoch_count += 1
  1187. assert epoch_count == num_epoch
  1188. logger.info("test_cache_map_coco1 Ended.\n")
  1189. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1190. def test_cache_map_coco2():
  1191. """
  1192. Test mappable coco leaf with the cache op later in the tree above the map(resize)
  1193. cache
  1194. |
  1195. Map(resize)
  1196. |
  1197. Coco
  1198. """
  1199. logger.info("Test cache map coco2")
  1200. if "SESSION_ID" in os.environ:
  1201. session_id = int(os.environ['SESSION_ID'])
  1202. else:
  1203. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1204. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1205. # This dataset has 6 records
  1206. ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
  1207. resize_op = c_vision.Resize((224, 224))
  1208. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1209. num_epoch = 4
  1210. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1211. epoch_count = 0
  1212. for _ in range(num_epoch):
  1213. assert sum([1 for _ in iter1]) == 6
  1214. epoch_count += 1
  1215. assert epoch_count == num_epoch
  1216. logger.info("test_cache_map_coco2 Ended.\n")
  1217. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1218. def test_cache_map_mnist1():
  1219. """
  1220. Test mappable mnist leaf with cache op right over the leaf
  1221. cache
  1222. |
  1223. Mnist
  1224. """
  1225. logger.info("Test cache map mnist1")
  1226. if "SESSION_ID" in os.environ:
  1227. session_id = int(os.environ['SESSION_ID'])
  1228. else:
  1229. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1230. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1231. ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)
  1232. num_epoch = 4
  1233. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1234. epoch_count = 0
  1235. for _ in range(num_epoch):
  1236. assert sum([1 for _ in iter1]) == 10
  1237. epoch_count += 1
  1238. assert epoch_count == num_epoch
  1239. logger.info("test_cache_map_mnist1 Ended.\n")
  1240. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1241. def test_cache_map_mnist2():
  1242. """
  1243. Test mappable mnist leaf with the cache op later in the tree above the map(resize)
  1244. cache
  1245. |
  1246. Map(resize)
  1247. |
  1248. Mnist
  1249. """
  1250. logger.info("Test cache map mnist2")
  1251. if "SESSION_ID" in os.environ:
  1252. session_id = int(os.environ['SESSION_ID'])
  1253. else:
  1254. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1255. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1256. ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
  1257. resize_op = c_vision.Resize((224, 224))
  1258. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1259. num_epoch = 4
  1260. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1261. epoch_count = 0
  1262. for _ in range(num_epoch):
  1263. assert sum([1 for _ in iter1]) == 10
  1264. epoch_count += 1
  1265. assert epoch_count == num_epoch
  1266. logger.info("test_cache_map_mnist2 Ended.\n")
  1267. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1268. def test_cache_map_celeba1():
  1269. """
  1270. Test mappable celeba leaf with cache op right over the leaf
  1271. cache
  1272. |
  1273. CelebA
  1274. """
  1275. logger.info("Test cache map celeba1")
  1276. if "SESSION_ID" in os.environ:
  1277. session_id = int(os.environ['SESSION_ID'])
  1278. else:
  1279. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1280. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1281. # This dataset has 4 records
  1282. ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)
  1283. num_epoch = 4
  1284. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1285. epoch_count = 0
  1286. for _ in range(num_epoch):
  1287. assert sum([1 for _ in iter1]) == 4
  1288. epoch_count += 1
  1289. assert epoch_count == num_epoch
  1290. logger.info("test_cache_map_celeba1 Ended.\n")
  1291. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1292. def test_cache_map_celeba2():
  1293. """
  1294. Test mappable celeba leaf with the cache op later in the tree above the map(resize)
  1295. cache
  1296. |
  1297. Map(resize)
  1298. |
  1299. CelebA
  1300. """
  1301. logger.info("Test cache map celeba2")
  1302. if "SESSION_ID" in os.environ:
  1303. session_id = int(os.environ['SESSION_ID'])
  1304. else:
  1305. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1306. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1307. # This dataset has 4 records
  1308. ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
  1309. resize_op = c_vision.Resize((224, 224))
  1310. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1311. num_epoch = 4
  1312. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1313. epoch_count = 0
  1314. for _ in range(num_epoch):
  1315. assert sum([1 for _ in iter1]) == 4
  1316. epoch_count += 1
  1317. assert epoch_count == num_epoch
  1318. logger.info("test_cache_map_celeba2 Ended.\n")
  1319. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1320. def test_cache_map_manifest1():
  1321. """
  1322. Test mappable manifest leaf with cache op right over the leaf
  1323. cache
  1324. |
  1325. Manifest
  1326. """
  1327. logger.info("Test cache map manifest1")
  1328. if "SESSION_ID" in os.environ:
  1329. session_id = int(os.environ['SESSION_ID'])
  1330. else:
  1331. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1332. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1333. # This dataset has 4 records
  1334. ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
  1335. num_epoch = 4
  1336. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1337. epoch_count = 0
  1338. for _ in range(num_epoch):
  1339. assert sum([1 for _ in iter1]) == 4
  1340. epoch_count += 1
  1341. assert epoch_count == num_epoch
  1342. logger.info("test_cache_map_manifest1 Ended.\n")
  1343. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1344. def test_cache_map_manifest2():
  1345. """
  1346. Test mappable manifest leaf with the cache op later in the tree above the map(resize)
  1347. cache
  1348. |
  1349. Map(resize)
  1350. |
  1351. Manifest
  1352. """
  1353. logger.info("Test cache map manifest2")
  1354. if "SESSION_ID" in os.environ:
  1355. session_id = int(os.environ['SESSION_ID'])
  1356. else:
  1357. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1358. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1359. # This dataset has 4 records
  1360. ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
  1361. resize_op = c_vision.Resize((224, 224))
  1362. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1363. num_epoch = 4
  1364. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1365. epoch_count = 0
  1366. for _ in range(num_epoch):
  1367. assert sum([1 for _ in iter1]) == 4
  1368. epoch_count += 1
  1369. assert epoch_count == num_epoch
  1370. logger.info("test_cache_map_manifest2 Ended.\n")
  1371. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1372. def test_cache_map_cifar1():
  1373. """
  1374. Test mappable cifar10 leaf with cache op right over the leaf
  1375. cache
  1376. |
  1377. Cifar10
  1378. """
  1379. logger.info("Test cache map cifar1")
  1380. if "SESSION_ID" in os.environ:
  1381. session_id = int(os.environ['SESSION_ID'])
  1382. else:
  1383. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1384. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1385. ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
  1386. num_epoch = 4
  1387. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1388. epoch_count = 0
  1389. for _ in range(num_epoch):
  1390. assert sum([1 for _ in iter1]) == 10
  1391. epoch_count += 1
  1392. assert epoch_count == num_epoch
  1393. logger.info("test_cache_map_cifar1 Ended.\n")
  1394. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1395. def test_cache_map_cifar2():
  1396. """
  1397. Test mappable cifar100 leaf with the cache op later in the tree above the map(resize)
  1398. cache
  1399. |
  1400. Map(resize)
  1401. |
  1402. Cifar100
  1403. """
  1404. logger.info("Test cache map cifar2")
  1405. if "SESSION_ID" in os.environ:
  1406. session_id = int(os.environ['SESSION_ID'])
  1407. else:
  1408. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1409. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1410. ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
  1411. resize_op = c_vision.Resize((224, 224))
  1412. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1413. num_epoch = 4
  1414. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1415. epoch_count = 0
  1416. for _ in range(num_epoch):
  1417. assert sum([1 for _ in iter1]) == 10
  1418. epoch_count += 1
  1419. assert epoch_count == num_epoch
  1420. logger.info("test_cache_map_cifar2 Ended.\n")
  1421. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1422. def test_cache_map_cifar3():
  1423. """
  1424. Test mappable cifar10 leaf with the cache op later in the tree above the map(resize)
  1425. In this case, we set a extra-small size for cache (size=1) and there are 10000 rows in the dataset.
  1426. cache
  1427. |
  1428. Cifar10
  1429. """
  1430. logger.info("Test cache map cifar3")
  1431. if "SESSION_ID" in os.environ:
  1432. session_id = int(os.environ['SESSION_ID'])
  1433. else:
  1434. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1435. some_cache = ds.DatasetCache(session_id=session_id, size=1)
  1436. ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
  1437. num_epoch = 2
  1438. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1439. epoch_count = 0
  1440. for _ in range(num_epoch):
  1441. assert sum([1 for _ in iter1]) == 10000
  1442. epoch_count += 1
  1443. assert epoch_count == num_epoch
  1444. logger.info("test_cache_map_cifar3 Ended.\n")
  1445. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1446. def test_cache_map_cifar4():
  1447. """
  1448. Test mappable cifar10 leaf with cache op right over the leaf, and shuffle op over the cache op
  1449. shuffle
  1450. |
  1451. cache
  1452. |
  1453. Cifar10
  1454. """
  1455. logger.info("Test cache map cifar4")
  1456. if "SESSION_ID" in os.environ:
  1457. session_id = int(os.environ['SESSION_ID'])
  1458. else:
  1459. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1460. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1461. ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
  1462. ds1 = ds1.shuffle(10)
  1463. num_epoch = 1
  1464. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1465. epoch_count = 0
  1466. for _ in range(num_epoch):
  1467. assert sum([1 for _ in iter1]) == 10
  1468. epoch_count += 1
  1469. assert epoch_count == num_epoch
  1470. logger.info("test_cache_map_cifar4 Ended.\n")
  1471. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1472. def test_cache_map_voc1():
  1473. """
  1474. Test mappable voc leaf with cache op right over the leaf
  1475. cache
  1476. |
  1477. VOC
  1478. """
  1479. logger.info("Test cache map voc1")
  1480. if "SESSION_ID" in os.environ:
  1481. session_id = int(os.environ['SESSION_ID'])
  1482. else:
  1483. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1484. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1485. # This dataset has 9 records
  1486. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)
  1487. num_epoch = 4
  1488. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1489. epoch_count = 0
  1490. for _ in range(num_epoch):
  1491. assert sum([1 for _ in iter1]) == 9
  1492. epoch_count += 1
  1493. assert epoch_count == num_epoch
  1494. logger.info("test_cache_map_voc1 Ended.\n")
  1495. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1496. def test_cache_map_voc2():
  1497. """
  1498. Test mappable voc leaf with the cache op later in the tree above the map(resize)
  1499. cache
  1500. |
  1501. Map(resize)
  1502. |
  1503. VOC
  1504. """
  1505. logger.info("Test cache map voc2")
  1506. if "SESSION_ID" in os.environ:
  1507. session_id = int(os.environ['SESSION_ID'])
  1508. else:
  1509. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1510. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1511. # This dataset has 9 records
  1512. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
  1513. resize_op = c_vision.Resize((224, 224))
  1514. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1515. num_epoch = 4
  1516. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1517. epoch_count = 0
  1518. for _ in range(num_epoch):
  1519. assert sum([1 for _ in iter1]) == 9
  1520. epoch_count += 1
  1521. assert epoch_count == num_epoch
  1522. logger.info("test_cache_map_voc2 Ended.\n")
  1523. class ReverseSampler(ds.Sampler):
  1524. def __iter__(self):
  1525. for i in range(self.dataset_size - 1, -1, -1):
  1526. yield i
  1527. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1528. def test_cache_map_mindrecord1():
  1529. """
  1530. Test mappable mindrecord leaf with cache op right over the leaf
  1531. cache
  1532. |
  1533. MindRecord
  1534. """
  1535. logger.info("Test cache map mindrecord1")
  1536. if "SESSION_ID" in os.environ:
  1537. session_id = int(os.environ['SESSION_ID'])
  1538. else:
  1539. session_id = 1
  1540. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1541. # This dataset has 5 records
  1542. columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
  1543. ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, cache=some_cache)
  1544. num_epoch = 4
  1545. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
  1546. epoch_count = 0
  1547. for _ in range(num_epoch):
  1548. assert sum([1 for _ in iter1]) == 5
  1549. epoch_count += 1
  1550. assert epoch_count == num_epoch
  1551. logger.info("test_cache_map_mindrecord1 Ended.\n")
  1552. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1553. def test_cache_map_mindrecord2():
  1554. """
  1555. Test mappable mindrecord leaf with the cache op later in the tree above the map(decode)
  1556. cache
  1557. |
  1558. Map(decode)
  1559. |
  1560. MindRecord
  1561. """
  1562. logger.info("Test cache map mindrecord2")
  1563. if "SESSION_ID" in os.environ:
  1564. session_id = int(os.environ['SESSION_ID'])
  1565. else:
  1566. session_id = 1
  1567. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1568. # This dataset has 5 records
  1569. columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
  1570. ds1 = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list)
  1571. decode_op = c_vision.Decode()
  1572. ds1 = ds1.map(input_columns=["img_data"], operations=decode_op, cache=some_cache)
  1573. num_epoch = 4
  1574. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
  1575. epoch_count = 0
  1576. for _ in range(num_epoch):
  1577. assert sum([1 for _ in iter1]) == 5
  1578. epoch_count += 1
  1579. assert epoch_count == num_epoch
  1580. logger.info("test_cache_map_mindrecord2 Ended.\n")
  1581. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1582. def test_cache_map_python_sampler1():
  1583. """
  1584. Test using a python sampler, and cache after leaf
  1585. Repeat
  1586. |
  1587. Map(decode)
  1588. |
  1589. cache
  1590. |
  1591. ImageFolder
  1592. """
  1593. logger.info("Test cache map python sampler1")
  1594. if "SESSION_ID" in os.environ:
  1595. session_id = int(os.environ['SESSION_ID'])
  1596. else:
  1597. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1598. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1599. # This DATA_DIR only has 2 images in it
  1600. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
  1601. decode_op = c_vision.Decode()
  1602. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  1603. ds1 = ds1.repeat(4)
  1604. num_iter = 0
  1605. for _ in ds1.create_dict_iterator():
  1606. num_iter += 1
  1607. logger.info("Number of data in ds1: {} ".format(num_iter))
  1608. assert num_iter == 8
  1609. logger.info("test_cache_map_python_sampler1 Ended.\n")
  1610. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1611. def test_cache_map_python_sampler2():
  1612. """
  1613. Test using a python sampler, and cache after map
  1614. Repeat
  1615. |
  1616. cache
  1617. |
  1618. Map(decode)
  1619. |
  1620. ImageFolder
  1621. """
  1622. logger.info("Test cache map python sampler2")
  1623. if "SESSION_ID" in os.environ:
  1624. session_id = int(os.environ['SESSION_ID'])
  1625. else:
  1626. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1627. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1628. # This DATA_DIR only has 2 images in it
  1629. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
  1630. decode_op = c_vision.Decode()
  1631. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  1632. ds1 = ds1.repeat(4)
  1633. num_iter = 0
  1634. for _ in ds1.create_dict_iterator():
  1635. num_iter += 1
  1636. logger.info("Number of data in ds1: {} ".format(num_iter))
  1637. assert num_iter == 8
  1638. logger.info("test_cache_map_python_sampler2 Ended.\n")
  1639. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1640. def test_cache_map_nested_repeat():
  1641. """
  1642. Test cache on pipeline with nested repeat ops
  1643. Repeat
  1644. |
  1645. Map(decode)
  1646. |
  1647. Repeat
  1648. |
  1649. Cache
  1650. |
  1651. ImageFolder
  1652. """
  1653. logger.info("Test cache map nested repeat")
  1654. if "SESSION_ID" in os.environ:
  1655. session_id = int(os.environ['SESSION_ID'])
  1656. else:
  1657. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1658. some_cache = ds.DatasetCache(session_id=session_id, size=0)
  1659. # This DATA_DIR only has 2 images in it
  1660. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  1661. decode_op = c_vision.Decode()
  1662. ds1 = ds1.repeat(4)
  1663. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  1664. ds1 = ds1.repeat(2)
  1665. num_iter = 0
  1666. for _ in ds1.create_dict_iterator(num_epochs=1):
  1667. logger.info("get data from dataset")
  1668. num_iter += 1
  1669. logger.info("Number of data in ds1: {} ".format(num_iter))
  1670. assert num_iter == 16
  1671. logger.info('test_cache_map_nested_repeat Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_map_interrupt_and_rerun():
    """
    Test interrupt a running pipeline and then re-use the same cache to run another pipeline

       cache
         |
      Cifar10
    """
    logger.info("Test cache map interrupt and rerun")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, cache=some_cache)
    iter1 = ds1.create_dict_iterator()

    num_iter = 0
    # Interrupt the first pipeline after 10 rows: stop() tears the iterator down,
    # so the next fetch raises AttributeError on the missing runtime context.
    with pytest.raises(AttributeError) as e:
        for _ in iter1:
            num_iter += 1
            if num_iter == 10:
                iter1.stop()
    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)

    # Re-use the same cache object for a fresh 2-epoch run over the full dataset.
    num_epoch = 2
    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        num_iter = 0
        for _ in iter2:
            num_iter += 1
        logger.info("Number of data in ds1: {} ".format(num_iter))
        assert num_iter == 10000
        epoch_count += 1

    # After the reruns all 10000 rows should be resident in the cache.
    cache_stat = some_cache.GetStat()
    assert cache_stat.num_mem_cached == 10000

    logger.info("test_cache_map_interrupt_and_rerun Ended.\n")
  1708. if __name__ == '__main__':
  1709. # This is just a list of tests, don't try to run these tests with 'python test_cache_map.py'
  1710. # since cache server is required to be brought up first
  1711. test_cache_map_basic1()
  1712. test_cache_map_basic2()
  1713. test_cache_map_basic3()
  1714. test_cache_map_basic4()
  1715. test_cache_map_basic5()
  1716. test_cache_map_failure1()
  1717. test_cache_map_failure2()
  1718. test_cache_map_failure3()
  1719. test_cache_map_failure4()
  1720. test_cache_map_failure5()
  1721. test_cache_map_failure7()
  1722. test_cache_map_failure8()
  1723. test_cache_map_failure9()
  1724. test_cache_map_failure10()
  1725. test_cache_map_failure11()
  1726. test_cache_map_split1()
  1727. test_cache_map_split2()
  1728. test_cache_map_parameter_check()
  1729. test_cache_map_running_twice1()
  1730. test_cache_map_running_twice2()
  1731. test_cache_map_extra_small_size1()
  1732. test_cache_map_extra_small_size2()
  1733. test_cache_map_no_image()
  1734. test_cache_map_parallel_pipeline1(shard=0)
  1735. test_cache_map_parallel_pipeline2(shard=1)
  1736. test_cache_map_parallel_workers()
  1737. test_cache_map_server_workers_1()
  1738. test_cache_map_server_workers_100()
  1739. test_cache_map_num_connections_1()
  1740. test_cache_map_num_connections_100()
  1741. test_cache_map_prefetch_size_1()
  1742. test_cache_map_prefetch_size_100()
  1743. test_cache_map_to_device()
  1744. test_cache_map_epoch_ctrl1()
  1745. test_cache_map_epoch_ctrl2()
  1746. test_cache_map_epoch_ctrl3()
  1747. test_cache_map_coco1()
  1748. test_cache_map_coco2()
  1749. test_cache_map_mnist1()
  1750. test_cache_map_mnist2()
  1751. test_cache_map_celeba1()
  1752. test_cache_map_celeba2()
  1753. test_cache_map_manifest1()
  1754. test_cache_map_manifest2()
  1755. test_cache_map_cifar1()
  1756. test_cache_map_cifar2()
  1757. test_cache_map_cifar3()
  1758. test_cache_map_cifar4()
  1759. test_cache_map_voc1()
  1760. test_cache_map_voc2()
  1761. test_cache_map_mindrecord1()
  1762. test_cache_map_mindrecord2()
  1763. test_cache_map_python_sampler1()
  1764. test_cache_map_python_sampler2()
  1765. test_cache_map_nested_repeat()