You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_cache_map.py 57 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing cache operator with mappable datasets
  17. """
  18. import os
  19. import pytest
  20. import numpy as np
  21. import mindspore.dataset as ds
  22. import mindspore.dataset.vision.c_transforms as c_vision
  23. from mindspore import log as logger
  24. from util import save_and_check_md5
# Input dataset paths used by the cache test cases, relative to the test
# working directory.
DATA_DIR = "../data/dataset/testImageNetData/train/"  # ImageFolder data; only 2 images
COCO_DATA_DIR = "../data/dataset/testCOCO/train/"
COCO_ANNOTATION_FILE = "../data/dataset/testCOCO/annotations/train.json"
NO_IMAGE_DIR = "../data/dataset/testRandomData/"  # presumably contains no images; used by test_cache_map_no_image
MNIST_DATA_DIR = "../data/dataset/testMnistData/"
CELEBA_DATA_DIR = "../data/dataset/testCelebAData/"
VOC_DATA_DIR = "../data/dataset/testVOC2012/"
MANIFEST_DATA_FILE = "../data/dataset/testManifestData/test.manifest"
CIFAR10_DATA_DIR = "../data/dataset/testCifar10Data/"
CIFAR100_DATA_DIR = "../data/dataset/testCifar100Data/"
MIND_RECORD_DATA_DIR = "../data/mindrecord/testTwoImageData/twobytes.mindrecord"  # MindRecord file with 5 records
# When True, save_and_check_md5 regenerates the golden .npz result files
# instead of comparing against them.
GENERATE_GOLDEN = False
  37. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  38. def test_cache_map_basic1():
  39. """
  40. Test mappable leaf with cache op right over the leaf
  41. Repeat
  42. |
  43. Map(decode)
  44. |
  45. Cache
  46. |
  47. ImageFolder
  48. """
  49. logger.info("Test cache map basic 1")
  50. if "SESSION_ID" in os.environ:
  51. session_id = int(os.environ['SESSION_ID'])
  52. else:
  53. session_id = 1
  54. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  55. # This DATA_DIR only has 2 images in it
  56. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  57. decode_op = c_vision.Decode()
  58. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  59. ds1 = ds1.repeat(4)
  60. filename = "cache_map_01_result.npz"
  61. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  62. logger.info("test_cache_map_basic1 Ended.\n")
  63. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  64. def test_cache_map_basic2():
  65. """
  66. Test mappable leaf with the cache op later in the tree above the map(decode)
  67. Repeat
  68. |
  69. Cache
  70. |
  71. Map(decode)
  72. |
  73. ImageFolder
  74. """
  75. logger.info("Test cache map basic 2")
  76. if "SESSION_ID" in os.environ:
  77. session_id = int(os.environ['SESSION_ID'])
  78. else:
  79. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  80. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  81. # This DATA_DIR only has 2 images in it
  82. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  83. decode_op = c_vision.Decode()
  84. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  85. ds1 = ds1.repeat(4)
  86. filename = "cache_map_02_result.npz"
  87. save_and_check_md5(ds1, filename, generate_golden=GENERATE_GOLDEN)
  88. logger.info("test_cache_map_basic2 Ended.\n")
  89. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  90. def test_cache_map_basic3():
  91. """
  92. Test different rows result in core dump
  93. """
  94. logger.info("Test cache basic 3")
  95. if "SESSION_ID" in os.environ:
  96. session_id = int(os.environ['SESSION_ID'])
  97. else:
  98. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  99. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  100. # This DATA_DIR only has 2 images in it
  101. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  102. decode_op = c_vision.Decode()
  103. ds1 = ds1.repeat(4)
  104. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  105. logger.info("ds1.dataset_size is ", ds1.get_dataset_size())
  106. shape = ds1.output_shapes()
  107. logger.info(shape)
  108. num_iter = 0
  109. for _ in ds1.create_dict_iterator(num_epochs=1):
  110. logger.info("get data from dataset")
  111. num_iter += 1
  112. logger.info("Number of data in ds1: {} ".format(num_iter))
  113. assert num_iter == 8
  114. logger.info('test_cache_basic3 Ended.\n')
  115. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  116. def test_cache_map_basic4():
  117. """
  118. Test Map with non-deterministic TensorOps above cache
  119. repeat
  120. |
  121. Map(decode, randomCrop)
  122. |
  123. Cache
  124. |
  125. ImageFolder
  126. """
  127. logger.info("Test cache basic 4")
  128. if "SESSION_ID" in os.environ:
  129. session_id = int(os.environ['SESSION_ID'])
  130. else:
  131. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  132. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  133. # This DATA_DIR only has 2 images in it
  134. data = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  135. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  136. decode_op = c_vision.Decode()
  137. data = data.map(input_columns=["image"], operations=decode_op)
  138. data = data.map(input_columns=["image"], operations=random_crop_op)
  139. data = data.repeat(4)
  140. num_iter = 0
  141. for _ in data.create_dict_iterator():
  142. num_iter += 1
  143. logger.info("Number of data in ds1: {} ".format(num_iter))
  144. assert num_iter == 8
  145. logger.info('test_cache_basic4 Ended.\n')
  146. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  147. def test_cache_map_basic5():
  148. """
  149. Test cache as root node
  150. cache
  151. |
  152. ImageFolder
  153. """
  154. logger.info("Test cache basic 5")
  155. if "SESSION_ID" in os.environ:
  156. session_id = int(os.environ['SESSION_ID'])
  157. else:
  158. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  159. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  160. # This DATA_DIR only has 2 images in it
  161. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  162. num_iter = 0
  163. for _ in ds1.create_dict_iterator(num_epochs=1):
  164. logger.info("get data from dataset")
  165. num_iter += 1
  166. logger.info("Number of data in ds1: {} ".format(num_iter))
  167. assert num_iter == 2
  168. logger.info('test_cache_basic5 Ended.\n')
  169. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  170. def test_cache_map_failure1():
  171. """
  172. Test nested cache (failure)
  173. Repeat
  174. |
  175. Cache
  176. |
  177. Map(decode)
  178. |
  179. Cache
  180. |
  181. ImageFolder
  182. """
  183. logger.info("Test cache failure 1")
  184. if "SESSION_ID" in os.environ:
  185. session_id = int(os.environ['SESSION_ID'])
  186. else:
  187. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  188. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  189. # This DATA_DIR only has 2 images in it
  190. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  191. decode_op = c_vision.Decode()
  192. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  193. ds1 = ds1.repeat(4)
  194. with pytest.raises(RuntimeError) as e:
  195. num_iter = 0
  196. for _ in ds1.create_dict_iterator(num_epochs=1):
  197. num_iter += 1
  198. assert "Nested cache operations is not supported!" in str(e.value)
  199. assert num_iter == 0
  200. logger.info('test_cache_failure1 Ended.\n')
  201. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  202. def test_cache_map_failure2():
  203. """
  204. Test zip under cache (failure)
  205. repeat
  206. |
  207. Cache
  208. |
  209. Map(decode)
  210. |
  211. Zip
  212. | |
  213. ImageFolder ImageFolder
  214. """
  215. logger.info("Test cache failure 2")
  216. if "SESSION_ID" in os.environ:
  217. session_id = int(os.environ['SESSION_ID'])
  218. else:
  219. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  220. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  221. # This DATA_DIR only has 2 images in it
  222. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  223. ds2 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  224. dsz = ds.zip((ds1, ds2))
  225. decode_op = c_vision.Decode()
  226. dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  227. dsz = dsz.repeat(4)
  228. with pytest.raises(RuntimeError) as e:
  229. num_iter = 0
  230. for _ in dsz.create_dict_iterator():
  231. num_iter += 1
  232. assert "ZipOp is currently not supported as a descendant operator under a cache" in str(e.value)
  233. assert num_iter == 0
  234. logger.info('test_cache_failure2 Ended.\n')
  235. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  236. def test_cache_map_failure3():
  237. """
  238. Test batch under cache (failure)
  239. repeat
  240. |
  241. Cache
  242. |
  243. Map(resize)
  244. |
  245. Batch
  246. |
  247. ImageFolder
  248. """
  249. logger.info("Test cache failure 3")
  250. if "SESSION_ID" in os.environ:
  251. session_id = int(os.environ['SESSION_ID'])
  252. else:
  253. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  254. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  255. # This DATA_DIR only has 2 images in it
  256. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  257. ds1 = ds1.batch(2)
  258. resize_op = c_vision.Resize((224, 224))
  259. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  260. ds1 = ds1.repeat(4)
  261. with pytest.raises(RuntimeError) as e:
  262. num_iter = 0
  263. for _ in ds1.create_dict_iterator():
  264. num_iter += 1
  265. assert "Unexpected error. Expect positive row id: -1" in str(e.value)
  266. assert num_iter == 0
  267. logger.info('test_cache_failure3 Ended.\n')
  268. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  269. def test_cache_map_failure4():
  270. """
  271. Test filter under cache (failure)
  272. repeat
  273. |
  274. Cache
  275. |
  276. Map(decode)
  277. |
  278. Filter
  279. |
  280. ImageFolder
  281. """
  282. logger.info("Test cache failure 4")
  283. if "SESSION_ID" in os.environ:
  284. session_id = int(os.environ['SESSION_ID'])
  285. else:
  286. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  287. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  288. # This DATA_DIR only has 2 images in it
  289. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  290. ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])
  291. decode_op = c_vision.Decode()
  292. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  293. ds1 = ds1.repeat(4)
  294. with pytest.raises(RuntimeError) as e:
  295. num_iter = 0
  296. for _ in ds1.create_dict_iterator():
  297. num_iter += 1
  298. assert "FilterOp is currently not supported as a descendant operator under a cache" in str(e.value)
  299. assert num_iter == 0
  300. logger.info('test_cache_failure4 Ended.\n')
  301. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  302. def test_cache_map_failure5():
  303. """
  304. Test Map with non-deterministic TensorOps under cache (failure)
  305. repeat
  306. |
  307. Cache
  308. |
  309. Map(decode, randomCrop)
  310. |
  311. ImageFolder
  312. """
  313. logger.info("Test cache failure 5")
  314. if "SESSION_ID" in os.environ:
  315. session_id = int(os.environ['SESSION_ID'])
  316. else:
  317. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  318. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  319. # This DATA_DIR only has 2 images in it
  320. data = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  321. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  322. decode_op = c_vision.Decode()
  323. data = data.map(input_columns=["image"], operations=decode_op)
  324. data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
  325. data = data.repeat(4)
  326. with pytest.raises(RuntimeError) as e:
  327. num_iter = 0
  328. for _ in data.create_dict_iterator():
  329. num_iter += 1
  330. assert "MapOp with non-deterministic TensorOps is currently not supported as a descendant of cache" in str(e.value)
  331. assert num_iter == 0
  332. logger.info('test_cache_failure5 Ended.\n')
  333. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  334. def test_cache_map_failure6():
  335. """
  336. Test no-cache-supporting MindRecord leaf with Map under cache (failure)
  337. repeat
  338. |
  339. Cache
  340. |
  341. Map(resize)
  342. |
  343. MindRecord
  344. """
  345. logger.info("Test cache failure 6")
  346. if "SESSION_ID" in os.environ:
  347. session_id = int(os.environ['SESSION_ID'])
  348. else:
  349. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  350. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  351. columns_list = ["id", "file_name", "label_name", "img_data", "label_data"]
  352. num_readers = 1
  353. # The dataset has 5 records
  354. data = ds.MindDataset(MIND_RECORD_DATA_DIR, columns_list, num_readers)
  355. resize_op = c_vision.Resize((224, 224))
  356. data = data.map(input_columns=["img_data"], operations=resize_op, cache=some_cache)
  357. data = data.repeat(4)
  358. with pytest.raises(RuntimeError) as e:
  359. num_iter = 0
  360. for _ in data.create_dict_iterator():
  361. num_iter += 1
  362. assert "There is currently no support for MindRecordOp under cache" in str(e.value)
  363. assert num_iter == 0
  364. logger.info('test_cache_failure6 Ended.\n')
  365. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  366. def test_cache_map_failure7():
  367. """
  368. Test no-cache-supporting Generator leaf with Map under cache (failure)
  369. repeat
  370. |
  371. Cache
  372. |
  373. Map(lambda x: x)
  374. |
  375. Generator
  376. """
  377. def generator_1d():
  378. for i in range(64):
  379. yield (np.array(i),)
  380. logger.info("Test cache failure 7")
  381. if "SESSION_ID" in os.environ:
  382. session_id = int(os.environ['SESSION_ID'])
  383. else:
  384. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  385. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  386. data = ds.GeneratorDataset(generator_1d, ["data"])
  387. data = data.map((lambda x: x), ["data"], cache=some_cache)
  388. data = data.repeat(4)
  389. with pytest.raises(RuntimeError) as e:
  390. num_iter = 0
  391. for _ in data.create_dict_iterator():
  392. num_iter += 1
  393. assert "There is currently no support for GeneratorOp under cache" in str(e.value)
  394. assert num_iter == 0
  395. logger.info('test_cache_failure7 Ended.\n')
  396. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  397. def test_cache_map_failure8():
  398. """
  399. Test a repeat under mappable cache (failure)
  400. Cache
  401. |
  402. Map(decode)
  403. |
  404. Repeat
  405. |
  406. ImageFolder
  407. """
  408. logger.info("Test cache failure 8")
  409. if "SESSION_ID" in os.environ:
  410. session_id = int(os.environ['SESSION_ID'])
  411. else:
  412. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  413. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  414. # This DATA_DIR only has 2 images in it
  415. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  416. decode_op = c_vision.Decode()
  417. ds1 = ds1.repeat(4)
  418. ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
  419. with pytest.raises(RuntimeError) as e:
  420. num_iter = 0
  421. for _ in ds1.create_dict_iterator(num_epochs=1):
  422. num_iter += 1
  423. assert "Repeat is not supported as a descendant operator under a mappable cache" in str(e.value)
  424. assert num_iter == 0
  425. logger.info('test_cache_failure8 Ended.\n')
  426. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  427. def test_cache_map_parameter_check():
  428. """
  429. Test illegal parameters for DatasetCache
  430. """
  431. logger.info("Test cache map parameter check")
  432. with pytest.raises(ValueError) as info:
  433. ds.DatasetCache(session_id=-1, size=0, spilling=True)
  434. assert "Input is not within the required interval" in str(info.value)
  435. with pytest.raises(TypeError) as info:
  436. ds.DatasetCache(session_id="1", size=0, spilling=True)
  437. assert "Argument session_id with value 1 is not of type (<class 'int'>,)" in str(info.value)
  438. with pytest.raises(TypeError) as info:
  439. ds.DatasetCache(session_id=None, size=0, spilling=True)
  440. assert "Argument session_id with value None is not of type (<class 'int'>,)" in str(info.value)
  441. with pytest.raises(ValueError) as info:
  442. ds.DatasetCache(session_id=1, size=-1, spilling=True)
  443. assert "Input is not within the required interval" in str(info.value)
  444. with pytest.raises(TypeError) as info:
  445. ds.DatasetCache(session_id=1, size="1", spilling=True)
  446. assert "Argument size with value 1 is not of type (<class 'int'>,)" in str(info.value)
  447. with pytest.raises(TypeError) as info:
  448. ds.DatasetCache(session_id=1, size=None, spilling=True)
  449. assert "Argument size with value None is not of type (<class 'int'>,)" in str(info.value)
  450. with pytest.raises(TypeError) as info:
  451. ds.DatasetCache(session_id=1, size=0, spilling="illegal")
  452. assert "Argument spilling with value illegal is not of type (<class 'bool'>,)" in str(info.value)
  453. with pytest.raises(RuntimeError) as err:
  454. ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="illegal")
  455. assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
  456. with pytest.raises(RuntimeError) as err:
  457. ds.DatasetCache(session_id=1, size=0, spilling=True, hostname="127.0.0.2")
  458. assert "Unexpected error. now cache client has to be on the same host with cache server" in str(err.value)
  459. with pytest.raises(TypeError) as info:
  460. ds.DatasetCache(session_id=1, size=0, spilling=True, port="illegal")
  461. assert "incompatible constructor arguments" in str(info.value)
  462. with pytest.raises(TypeError) as info:
  463. ds.DatasetCache(session_id=1, size=0, spilling=True, port="50052")
  464. assert "incompatible constructor arguments" in str(info.value)
  465. with pytest.raises(RuntimeError) as err:
  466. ds.DatasetCache(session_id=1, size=0, spilling=True, port=0)
  467. assert "Unexpected error. port must be positive" in str(err.value)
  468. with pytest.raises(RuntimeError) as err:
  469. ds.DatasetCache(session_id=1, size=0, spilling=True, port=65536)
  470. assert "Unexpected error. illegal port number" in str(err.value)
  471. with pytest.raises(TypeError) as err:
  472. ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=True)
  473. assert "Argument cache with value True is not of type" in str(err.value)
  474. logger.info("test_cache_map_parameter_check Ended.\n")
  475. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  476. def test_cache_map_running_twice1():
  477. """
  478. Executing the same pipeline for twice (from python), with cache injected after map
  479. Repeat
  480. |
  481. Cache
  482. |
  483. Map(decode)
  484. |
  485. ImageFolder
  486. """
  487. logger.info("Test cache map running twice 1")
  488. if "SESSION_ID" in os.environ:
  489. session_id = int(os.environ['SESSION_ID'])
  490. else:
  491. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  492. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  493. # This DATA_DIR only has 2 images in it
  494. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  495. decode_op = c_vision.Decode()
  496. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  497. ds1 = ds1.repeat(4)
  498. num_iter = 0
  499. for _ in ds1.create_dict_iterator():
  500. num_iter += 1
  501. logger.info("Number of data in ds1: {} ".format(num_iter))
  502. assert num_iter == 8
  503. num_iter = 0
  504. for _ in ds1.create_dict_iterator():
  505. num_iter += 1
  506. logger.info("Number of data in ds1: {} ".format(num_iter))
  507. assert num_iter == 8
  508. logger.info("test_cache_map_running_twice1 Ended.\n")
  509. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  510. def test_cache_map_running_twice2():
  511. """
  512. Executing the same pipeline for twice (from shell), with cache injected after leaf
  513. Repeat
  514. |
  515. Map(decode)
  516. |
  517. Cache
  518. |
  519. ImageFolder
  520. """
  521. logger.info("Test cache map running twice 2")
  522. if "SESSION_ID" in os.environ:
  523. session_id = int(os.environ['SESSION_ID'])
  524. else:
  525. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  526. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  527. # This DATA_DIR only has 2 images in it
  528. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  529. decode_op = c_vision.Decode()
  530. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  531. ds1 = ds1.repeat(4)
  532. num_iter = 0
  533. for _ in ds1.create_dict_iterator():
  534. num_iter += 1
  535. logger.info("Number of data in ds1: {} ".format(num_iter))
  536. assert num_iter == 8
  537. logger.info("test_cache_map_running_twice2 Ended.\n")
  538. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  539. def test_cache_map_extra_small_size1():
  540. """
  541. Test running pipeline with cache of extra small size and spilling true
  542. Repeat
  543. |
  544. Map(decode)
  545. |
  546. Cache
  547. |
  548. ImageFolder
  549. """
  550. logger.info("Test cache map extra small size 1")
  551. if "SESSION_ID" in os.environ:
  552. session_id = int(os.environ['SESSION_ID'])
  553. else:
  554. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  555. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)
  556. # This DATA_DIR only has 2 images in it
  557. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  558. decode_op = c_vision.Decode()
  559. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  560. ds1 = ds1.repeat(4)
  561. num_iter = 0
  562. for _ in ds1.create_dict_iterator():
  563. num_iter += 1
  564. logger.info("Number of data in ds1: {} ".format(num_iter))
  565. assert num_iter == 8
  566. logger.info("test_cache_map_extra_small_size1 Ended.\n")
  567. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  568. def test_cache_map_extra_small_size2():
  569. """
  570. Test running pipeline with cache of extra small size and spilling false
  571. Repeat
  572. |
  573. Cache
  574. |
  575. Map(decode)
  576. |
  577. ImageFolder
  578. """
  579. logger.info("Test cache map extra small size 2")
  580. if "SESSION_ID" in os.environ:
  581. session_id = int(os.environ['SESSION_ID'])
  582. else:
  583. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  584. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
  585. # This DATA_DIR only has 2 images in it
  586. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  587. decode_op = c_vision.Decode()
  588. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  589. ds1 = ds1.repeat(4)
  590. num_iter = 0
  591. for _ in ds1.create_dict_iterator():
  592. num_iter += 1
  593. logger.info("Number of data in ds1: {} ".format(num_iter))
  594. assert num_iter == 8
  595. logger.info("test_cache_map_extra_small_size2 Ended.\n")
  596. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  597. def test_cache_map_no_image():
  598. """
  599. Test cache with no dataset existing in the path
  600. Repeat
  601. |
  602. Map(decode)
  603. |
  604. Cache
  605. |
  606. ImageFolder
  607. """
  608. logger.info("Test cache map no image")
  609. if "SESSION_ID" in os.environ:
  610. session_id = int(os.environ['SESSION_ID'])
  611. else:
  612. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  613. some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
  614. # This DATA_DIR only has 2 images in it
  615. ds1 = ds.ImageFolderDataset(dataset_dir=NO_IMAGE_DIR, cache=some_cache)
  616. decode_op = c_vision.Decode()
  617. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  618. ds1 = ds1.repeat(4)
  619. with pytest.raises(RuntimeError):
  620. num_iter = 0
  621. for _ in ds1.create_dict_iterator():
  622. num_iter += 1
  623. assert num_iter == 0
  624. logger.info("test_cache_map_no_image Ended.\n")
  625. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  626. def test_cache_map_parallel_pipeline1(shard):
  627. """
  628. Test running two parallel pipelines (sharing cache) with cache injected after leaf op
  629. Repeat
  630. |
  631. Map(decode)
  632. |
  633. Cache
  634. |
  635. ImageFolder
  636. """
  637. logger.info("Test cache map parallel pipeline 1")
  638. if "SESSION_ID" in os.environ:
  639. session_id = int(os.environ['SESSION_ID'])
  640. else:
  641. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  642. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  643. # This DATA_DIR only has 2 images in it
  644. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard), cache=some_cache)
  645. decode_op = c_vision.Decode()
  646. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  647. ds1 = ds1.repeat(4)
  648. num_iter = 0
  649. for _ in ds1.create_dict_iterator():
  650. num_iter += 1
  651. logger.info("Number of data in ds1: {} ".format(num_iter))
  652. assert num_iter == 4
  653. logger.info("test_cache_map_parallel_pipeline1 Ended.\n")
  654. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  655. def test_cache_map_parallel_pipeline2(shard):
  656. """
  657. Test running two parallel pipelines (sharing cache) with cache injected after map op
  658. Repeat
  659. |
  660. Cache
  661. |
  662. Map(decode)
  663. |
  664. ImageFolder
  665. """
  666. logger.info("Test cache map parallel pipeline 2")
  667. if "SESSION_ID" in os.environ:
  668. session_id = int(os.environ['SESSION_ID'])
  669. else:
  670. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  671. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  672. # This DATA_DIR only has 2 images in it
  673. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_shards=2, shard_id=int(shard))
  674. decode_op = c_vision.Decode()
  675. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  676. ds1 = ds1.repeat(4)
  677. num_iter = 0
  678. for _ in ds1.create_dict_iterator():
  679. num_iter += 1
  680. logger.info("Number of data in ds1: {} ".format(num_iter))
  681. assert num_iter == 4
  682. logger.info("test_cache_map_parallel_pipeline2 Ended.\n")
  683. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  684. def test_cache_map_parallel_workers():
  685. """
  686. Test cache with num_parallel_workers > 1 set for map op and leaf op
  687. Repeat
  688. |
  689. cache
  690. |
  691. Map(decode)
  692. |
  693. ImageFolder
  694. """
  695. logger.info("Test cache map parallel workers")
  696. if "SESSION_ID" in os.environ:
  697. session_id = int(os.environ['SESSION_ID'])
  698. else:
  699. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  700. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  701. # This DATA_DIR only has 2 images in it
  702. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, num_parallel_workers=4)
  703. decode_op = c_vision.Decode()
  704. ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
  705. ds1 = ds1.repeat(4)
  706. num_iter = 0
  707. for _ in ds1.create_dict_iterator():
  708. num_iter += 1
  709. logger.info("Number of data in ds1: {} ".format(num_iter))
  710. assert num_iter == 8
  711. logger.info("test_cache_map_parallel_workers Ended.\n")
  712. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  713. def test_cache_map_server_workers_1():
  714. """
  715. start cache server with --workers 1 and then test cache function
  716. Repeat
  717. |
  718. cache
  719. |
  720. Map(decode)
  721. |
  722. ImageFolder
  723. """
  724. logger.info("Test cache map server workers 1")
  725. if "SESSION_ID" in os.environ:
  726. session_id = int(os.environ['SESSION_ID'])
  727. else:
  728. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  729. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  730. # This DATA_DIR only has 2 images in it
  731. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  732. decode_op = c_vision.Decode()
  733. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  734. ds1 = ds1.repeat(4)
  735. num_iter = 0
  736. for _ in ds1.create_dict_iterator():
  737. num_iter += 1
  738. logger.info("Number of data in ds1: {} ".format(num_iter))
  739. assert num_iter == 8
  740. logger.info("test_cache_map_server_workers_1 Ended.\n")
  741. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  742. def test_cache_map_server_workers_100():
  743. """
  744. start cache server with --workers 100 and then test cache function
  745. Repeat
  746. |
  747. Map(decode)
  748. |
  749. cache
  750. |
  751. ImageFolder
  752. """
  753. logger.info("Test cache map server workers 100")
  754. if "SESSION_ID" in os.environ:
  755. session_id = int(os.environ['SESSION_ID'])
  756. else:
  757. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  758. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  759. # This DATA_DIR only has 2 images in it
  760. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  761. decode_op = c_vision.Decode()
  762. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  763. ds1 = ds1.repeat(4)
  764. num_iter = 0
  765. for _ in ds1.create_dict_iterator():
  766. num_iter += 1
  767. logger.info("Number of data in ds1: {} ".format(num_iter))
  768. assert num_iter == 8
  769. logger.info("test_cache_map_server_workers_100 Ended.\n")
  770. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  771. def test_cache_map_num_connections_1():
  772. """
  773. Test setting num_connections=1 in DatasetCache
  774. Repeat
  775. |
  776. cache
  777. |
  778. Map(decode)
  779. |
  780. ImageFolder
  781. """
  782. logger.info("Test cache map num_connections 1")
  783. if "SESSION_ID" in os.environ:
  784. session_id = int(os.environ['SESSION_ID'])
  785. else:
  786. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  787. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
  788. # This DATA_DIR only has 2 images in it
  789. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  790. decode_op = c_vision.Decode()
  791. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  792. ds1 = ds1.repeat(4)
  793. num_iter = 0
  794. for _ in ds1.create_dict_iterator():
  795. num_iter += 1
  796. logger.info("Number of data in ds1: {} ".format(num_iter))
  797. assert num_iter == 8
  798. logger.info("test_cache_map_num_connections_1 Ended.\n")
  799. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  800. def test_cache_map_num_connections_100():
  801. """
  802. Test setting num_connections=100 in DatasetCache
  803. Repeat
  804. |
  805. Map(decode)
  806. |
  807. cache
  808. |
  809. ImageFolder
  810. """
  811. logger.info("Test cache map num_connections 100")
  812. if "SESSION_ID" in os.environ:
  813. session_id = int(os.environ['SESSION_ID'])
  814. else:
  815. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  816. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)
  817. # This DATA_DIR only has 2 images in it
  818. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  819. decode_op = c_vision.Decode()
  820. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  821. ds1 = ds1.repeat(4)
  822. num_iter = 0
  823. for _ in ds1.create_dict_iterator():
  824. num_iter += 1
  825. logger.info("Number of data in ds1: {} ".format(num_iter))
  826. assert num_iter == 8
  827. logger.info("test_cache_map_num_connections_100 Ended.\n")
  828. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  829. def test_cache_map_prefetch_size_1():
  830. """
  831. Test setting prefetch_size=1 in DatasetCache
  832. Repeat
  833. |
  834. cache
  835. |
  836. Map(decode)
  837. |
  838. ImageFolder
  839. """
  840. logger.info("Test cache map prefetch_size 1")
  841. if "SESSION_ID" in os.environ:
  842. session_id = int(os.environ['SESSION_ID'])
  843. else:
  844. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  845. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
  846. # This DATA_DIR only has 2 images in it
  847. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  848. decode_op = c_vision.Decode()
  849. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  850. ds1 = ds1.repeat(4)
  851. num_iter = 0
  852. for _ in ds1.create_dict_iterator():
  853. num_iter += 1
  854. logger.info("Number of data in ds1: {} ".format(num_iter))
  855. assert num_iter == 8
  856. logger.info("test_cache_map_prefetch_size_1 Ended.\n")
  857. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  858. def test_cache_map_prefetch_size_100():
  859. """
  860. Test setting prefetch_size=100 in DatasetCache
  861. Repeat
  862. |
  863. Map(decode)
  864. |
  865. cache
  866. |
  867. ImageFolder
  868. """
  869. logger.info("Test cache map prefetch_size 100")
  870. if "SESSION_ID" in os.environ:
  871. session_id = int(os.environ['SESSION_ID'])
  872. else:
  873. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  874. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)
  875. # This DATA_DIR only has 2 images in it
  876. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  877. decode_op = c_vision.Decode()
  878. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  879. ds1 = ds1.repeat(4)
  880. num_iter = 0
  881. for _ in ds1.create_dict_iterator():
  882. num_iter += 1
  883. logger.info("Number of data in ds1: {} ".format(num_iter))
  884. assert num_iter == 8
  885. logger.info("test_cache_map_prefetch_size_100 Ended.\n")
  886. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  887. def test_cache_map_to_device():
  888. """
  889. Test cache with to_device
  890. DeviceQueue
  891. |
  892. EpochCtrl
  893. |
  894. Repeat
  895. |
  896. Map(decode)
  897. |
  898. cache
  899. |
  900. ImageFolder
  901. """
  902. logger.info("Test cache map to_device")
  903. if "SESSION_ID" in os.environ:
  904. session_id = int(os.environ['SESSION_ID'])
  905. else:
  906. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  907. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  908. # This DATA_DIR only has 2 images in it
  909. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  910. decode_op = c_vision.Decode()
  911. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  912. ds1 = ds1.repeat(4)
  913. ds1 = ds1.to_device()
  914. ds1.send()
  915. logger.info("test_cache_map_to_device Ended.\n")
  916. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  917. def test_cache_map_epoch_ctrl1():
  918. """
  919. Test using two-loops method to run several epochs
  920. Map(decode)
  921. |
  922. cache
  923. |
  924. ImageFolder
  925. """
  926. logger.info("Test cache map epoch ctrl1")
  927. if "SESSION_ID" in os.environ:
  928. session_id = int(os.environ['SESSION_ID'])
  929. else:
  930. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  931. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  932. # This DATA_DIR only has 2 images in it
  933. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  934. decode_op = c_vision.Decode()
  935. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  936. num_epoch = 5
  937. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  938. epoch_count = 0
  939. for _ in range(num_epoch):
  940. row_count = 0
  941. for _ in iter1:
  942. row_count += 1
  943. logger.info("Number of data in ds1: {} ".format(row_count))
  944. assert row_count == 2
  945. epoch_count += 1
  946. assert epoch_count == num_epoch
  947. logger.info("test_cache_map_epoch_ctrl1 Ended.\n")
  948. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  949. def test_cache_map_epoch_ctrl2():
  950. """
  951. Test using two-loops method with infinite epochs
  952. cache
  953. |
  954. Map(decode)
  955. |
  956. ImageFolder
  957. """
  958. logger.info("Test cache map epoch ctrl2")
  959. if "SESSION_ID" in os.environ:
  960. session_id = int(os.environ['SESSION_ID'])
  961. else:
  962. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  963. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  964. # This DATA_DIR only has 2 images in it
  965. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR)
  966. decode_op = c_vision.Decode()
  967. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  968. num_epoch = 5
  969. # iter1 will always assume there is a next epoch and never shutdown
  970. iter1 = ds1.create_dict_iterator()
  971. epoch_count = 0
  972. for _ in range(num_epoch):
  973. row_count = 0
  974. for _ in iter1:
  975. row_count += 1
  976. logger.info("Number of data in ds1: {} ".format(row_count))
  977. assert row_count == 2
  978. epoch_count += 1
  979. assert epoch_count == num_epoch
  980. # manually stop the iterator
  981. iter1.stop()
  982. logger.info("test_cache_map_epoch_ctrl2 Ended.\n")
  983. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  984. def test_cache_map_epoch_ctrl3():
  985. """
  986. Test using two-loops method with infinite epochs over repeat
  987. repeat
  988. |
  989. Map(decode)
  990. |
  991. cache
  992. |
  993. ImageFolder
  994. """
  995. logger.info("Test cache map epoch ctrl3")
  996. if "SESSION_ID" in os.environ:
  997. session_id = int(os.environ['SESSION_ID'])
  998. else:
  999. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1000. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1001. # This DATA_DIR only has 2 images in it
  1002. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  1003. decode_op = c_vision.Decode()
  1004. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  1005. ds1 = ds1.repeat(2)
  1006. num_epoch = 5
  1007. # iter1 will always assume there is a next epoch and never shutdown
  1008. iter1 = ds1.create_dict_iterator()
  1009. epoch_count = 0
  1010. for _ in range(num_epoch):
  1011. row_count = 0
  1012. for _ in iter1:
  1013. row_count += 1
  1014. logger.info("Number of data in ds1: {} ".format(row_count))
  1015. assert row_count == 4
  1016. epoch_count += 1
  1017. assert epoch_count == num_epoch
  1018. # reply on garbage collector to destroy iter1
  1019. logger.info("test_cache_map_epoch_ctrl3 Ended.\n")
  1020. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1021. def test_cache_map_coco1():
  1022. """
  1023. Test mappable coco leaf with cache op right over the leaf
  1024. cache
  1025. |
  1026. Coco
  1027. """
  1028. logger.info("Test cache map coco1")
  1029. if "SESSION_ID" in os.environ:
  1030. session_id = int(os.environ['SESSION_ID'])
  1031. else:
  1032. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1033. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1034. # This dataset has 6 records
  1035. ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True,
  1036. cache=some_cache)
  1037. num_epoch = 4
  1038. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1039. epoch_count = 0
  1040. for _ in range(num_epoch):
  1041. assert sum([1 for _ in iter1]) == 6
  1042. epoch_count += 1
  1043. assert epoch_count == num_epoch
  1044. logger.info("test_cache_map_coco1 Ended.\n")
  1045. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1046. def test_cache_map_coco2():
  1047. """
  1048. Test mappable coco leaf with the cache op later in the tree above the map(resize)
  1049. cache
  1050. |
  1051. Map(resize)
  1052. |
  1053. Coco
  1054. """
  1055. logger.info("Test cache map coco2")
  1056. if "SESSION_ID" in os.environ:
  1057. session_id = int(os.environ['SESSION_ID'])
  1058. else:
  1059. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1060. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1061. # This dataset has 6 records
  1062. ds1 = ds.CocoDataset(COCO_DATA_DIR, annotation_file=COCO_ANNOTATION_FILE, task="Detection", decode=True)
  1063. resize_op = c_vision.Resize((224, 224))
  1064. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1065. num_epoch = 4
  1066. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1067. epoch_count = 0
  1068. for _ in range(num_epoch):
  1069. assert sum([1 for _ in iter1]) == 6
  1070. epoch_count += 1
  1071. assert epoch_count == num_epoch
  1072. logger.info("test_cache_map_coco2 Ended.\n")
  1073. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1074. def test_cache_map_mnist1():
  1075. """
  1076. Test mappable mnist leaf with cache op right over the leaf
  1077. cache
  1078. |
  1079. Mnist
  1080. """
  1081. logger.info("Test cache map mnist1")
  1082. if "SESSION_ID" in os.environ:
  1083. session_id = int(os.environ['SESSION_ID'])
  1084. else:
  1085. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1086. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1087. ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10, cache=some_cache)
  1088. num_epoch = 4
  1089. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1090. epoch_count = 0
  1091. for _ in range(num_epoch):
  1092. assert sum([1 for _ in iter1]) == 10
  1093. epoch_count += 1
  1094. assert epoch_count == num_epoch
  1095. logger.info("test_cache_map_mnist1 Ended.\n")
  1096. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1097. def test_cache_map_mnist2():
  1098. """
  1099. Test mappable mnist leaf with the cache op later in the tree above the map(resize)
  1100. cache
  1101. |
  1102. Map(resize)
  1103. |
  1104. Mnist
  1105. """
  1106. logger.info("Test cache map mnist2")
  1107. if "SESSION_ID" in os.environ:
  1108. session_id = int(os.environ['SESSION_ID'])
  1109. else:
  1110. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1111. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1112. ds1 = ds.MnistDataset(MNIST_DATA_DIR, num_samples=10)
  1113. resize_op = c_vision.Resize((224, 224))
  1114. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1115. num_epoch = 4
  1116. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1117. epoch_count = 0
  1118. for _ in range(num_epoch):
  1119. assert sum([1 for _ in iter1]) == 10
  1120. epoch_count += 1
  1121. assert epoch_count == num_epoch
  1122. logger.info("test_cache_map_mnist2 Ended.\n")
  1123. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1124. def test_cache_map_celeba1():
  1125. """
  1126. Test mappable celeba leaf with cache op right over the leaf
  1127. cache
  1128. |
  1129. CelebA
  1130. """
  1131. logger.info("Test cache map celeba1")
  1132. if "SESSION_ID" in os.environ:
  1133. session_id = int(os.environ['SESSION_ID'])
  1134. else:
  1135. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1136. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1137. # This dataset has 4 records
  1138. ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True, cache=some_cache)
  1139. num_epoch = 4
  1140. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1141. epoch_count = 0
  1142. for _ in range(num_epoch):
  1143. assert sum([1 for _ in iter1]) == 4
  1144. epoch_count += 1
  1145. assert epoch_count == num_epoch
  1146. logger.info("test_cache_map_celeba1 Ended.\n")
  1147. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1148. def test_cache_map_celeba2():
  1149. """
  1150. Test mappable celeba leaf with the cache op later in the tree above the map(resize)
  1151. cache
  1152. |
  1153. Map(resize)
  1154. |
  1155. CelebA
  1156. """
  1157. logger.info("Test cache map celeba2")
  1158. if "SESSION_ID" in os.environ:
  1159. session_id = int(os.environ['SESSION_ID'])
  1160. else:
  1161. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1162. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1163. # This dataset has 4 records
  1164. ds1 = ds.CelebADataset(CELEBA_DATA_DIR, shuffle=False, decode=True)
  1165. resize_op = c_vision.Resize((224, 224))
  1166. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1167. num_epoch = 4
  1168. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1169. epoch_count = 0
  1170. for _ in range(num_epoch):
  1171. assert sum([1 for _ in iter1]) == 4
  1172. epoch_count += 1
  1173. assert epoch_count == num_epoch
  1174. logger.info("test_cache_map_celeba2 Ended.\n")
  1175. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1176. def test_cache_map_manifest1():
  1177. """
  1178. Test mappable manifest leaf with cache op right over the leaf
  1179. cache
  1180. |
  1181. Manifest
  1182. """
  1183. logger.info("Test cache map manifest1")
  1184. if "SESSION_ID" in os.environ:
  1185. session_id = int(os.environ['SESSION_ID'])
  1186. else:
  1187. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1188. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1189. # This dataset has 4 records
  1190. ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True, cache=some_cache)
  1191. num_epoch = 4
  1192. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1193. epoch_count = 0
  1194. for _ in range(num_epoch):
  1195. assert sum([1 for _ in iter1]) == 4
  1196. epoch_count += 1
  1197. assert epoch_count == num_epoch
  1198. logger.info("test_cache_map_manifest1 Ended.\n")
  1199. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1200. def test_cache_map_manifest2():
  1201. """
  1202. Test mappable manifest leaf with the cache op later in the tree above the map(resize)
  1203. cache
  1204. |
  1205. Map(resize)
  1206. |
  1207. Manifest
  1208. """
  1209. logger.info("Test cache map manifest2")
  1210. if "SESSION_ID" in os.environ:
  1211. session_id = int(os.environ['SESSION_ID'])
  1212. else:
  1213. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1214. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1215. # This dataset has 4 records
  1216. ds1 = ds.ManifestDataset(MANIFEST_DATA_FILE, decode=True)
  1217. resize_op = c_vision.Resize((224, 224))
  1218. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1219. num_epoch = 4
  1220. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1221. epoch_count = 0
  1222. for _ in range(num_epoch):
  1223. assert sum([1 for _ in iter1]) == 4
  1224. epoch_count += 1
  1225. assert epoch_count == num_epoch
  1226. logger.info("test_cache_map_manifest2 Ended.\n")
  1227. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1228. def test_cache_map_cifar1():
  1229. """
  1230. Test mappable cifar10 leaf with cache op right over the leaf
  1231. cache
  1232. |
  1233. Cifar10
  1234. """
  1235. logger.info("Test cache map cifar1")
  1236. if "SESSION_ID" in os.environ:
  1237. session_id = int(os.environ['SESSION_ID'])
  1238. else:
  1239. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1240. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1241. ds1 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_samples=10, cache=some_cache)
  1242. num_epoch = 4
  1243. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1244. epoch_count = 0
  1245. for _ in range(num_epoch):
  1246. assert sum([1 for _ in iter1]) == 10
  1247. epoch_count += 1
  1248. assert epoch_count == num_epoch
  1249. logger.info("test_cache_map_cifar1 Ended.\n")
  1250. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1251. def test_cache_map_cifar2():
  1252. """
  1253. Test mappable cifar100 leaf with the cache op later in the tree above the map(resize)
  1254. cache
  1255. |
  1256. Map(resize)
  1257. |
  1258. Cifar100
  1259. """
  1260. logger.info("Test cache map cifar2")
  1261. if "SESSION_ID" in os.environ:
  1262. session_id = int(os.environ['SESSION_ID'])
  1263. else:
  1264. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1265. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1266. ds1 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_samples=10)
  1267. resize_op = c_vision.Resize((224, 224))
  1268. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1269. num_epoch = 4
  1270. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1271. epoch_count = 0
  1272. for _ in range(num_epoch):
  1273. assert sum([1 for _ in iter1]) == 10
  1274. epoch_count += 1
  1275. assert epoch_count == num_epoch
  1276. logger.info("test_cache_map_cifar2 Ended.\n")
  1277. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1278. def test_cache_map_voc1():
  1279. """
  1280. Test mappable voc leaf with cache op right over the leaf
  1281. cache
  1282. |
  1283. VOC
  1284. """
  1285. logger.info("Test cache map voc1")
  1286. if "SESSION_ID" in os.environ:
  1287. session_id = int(os.environ['SESSION_ID'])
  1288. else:
  1289. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1290. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1291. # This dataset has 9 records
  1292. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True, cache=some_cache)
  1293. num_epoch = 4
  1294. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1295. epoch_count = 0
  1296. for _ in range(num_epoch):
  1297. assert sum([1 for _ in iter1]) == 9
  1298. epoch_count += 1
  1299. assert epoch_count == num_epoch
  1300. logger.info("test_cache_map_voc1 Ended.\n")
  1301. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1302. def test_cache_map_voc2():
  1303. """
  1304. Test mappable voc leaf with the cache op later in the tree above the map(resize)
  1305. cache
  1306. |
  1307. Map(resize)
  1308. |
  1309. VOC
  1310. """
  1311. logger.info("Test cache map voc2")
  1312. if "SESSION_ID" in os.environ:
  1313. session_id = int(os.environ['SESSION_ID'])
  1314. else:
  1315. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1316. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1317. # This dataset has 9 records
  1318. ds1 = ds.VOCDataset(VOC_DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
  1319. resize_op = c_vision.Resize((224, 224))
  1320. ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
  1321. num_epoch = 4
  1322. iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
  1323. epoch_count = 0
  1324. for _ in range(num_epoch):
  1325. assert sum([1 for _ in iter1]) == 9
  1326. epoch_count += 1
  1327. assert epoch_count == num_epoch
  1328. logger.info("test_cache_map_voc2 Ended.\n")
  1329. class ReverseSampler(ds.Sampler):
  1330. def __iter__(self):
  1331. for i in range(self.dataset_size - 1, -1, -1):
  1332. yield i
  1333. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1334. def test_cache_map_python_sampler1():
  1335. """
  1336. Test using a python sampler, and cache after leaf
  1337. Repeat
  1338. |
  1339. Map(decode)
  1340. |
  1341. cache
  1342. |
  1343. ImageFolder
  1344. """
  1345. logger.info("Test cache map python sampler1")
  1346. if "SESSION_ID" in os.environ:
  1347. session_id = int(os.environ['SESSION_ID'])
  1348. else:
  1349. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1350. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1351. # This DATA_DIR only has 2 images in it
  1352. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler(), cache=some_cache)
  1353. decode_op = c_vision.Decode()
  1354. ds1 = ds1.map(input_columns=["image"], operations=decode_op)
  1355. ds1 = ds1.repeat(4)
  1356. num_iter = 0
  1357. for _ in ds1.create_dict_iterator():
  1358. num_iter += 1
  1359. logger.info("Number of data in ds1: {} ".format(num_iter))
  1360. assert num_iter == 8
  1361. logger.info("test_cache_map_python_sampler1 Ended.\n")
  1362. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1363. def test_cache_map_python_sampler2():
  1364. """
  1365. Test using a python sampler, and cache after map
  1366. Repeat
  1367. |
  1368. cache
  1369. |
  1370. Map(decode)
  1371. |
  1372. ImageFolder
  1373. """
  1374. logger.info("Test cache map python sampler2")
  1375. if "SESSION_ID" in os.environ:
  1376. session_id = int(os.environ['SESSION_ID'])
  1377. else:
  1378. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1379. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1380. # This DATA_DIR only has 2 images in it
  1381. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, sampler=ReverseSampler())
  1382. decode_op = c_vision.Decode()
  1383. ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
  1384. ds1 = ds1.repeat(4)
  1385. num_iter = 0
  1386. for _ in ds1.create_dict_iterator():
  1387. num_iter += 1
  1388. logger.info("Number of data in ds1: {} ".format(num_iter))
  1389. assert num_iter == 8
  1390. logger.info("test_cache_map_python_sampler2 Ended.\n")
  1391. @pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
  1392. def test_cache_map_nested_repeat():
  1393. """
  1394. Test cache on pipeline with nested repeat ops
  1395. Repeat
  1396. |
  1397. Map(decode)
  1398. |
  1399. Repeat
  1400. |
  1401. Cache
  1402. |
  1403. ImageFolder
  1404. """
  1405. logger.info("Test cache map nested repeat")
  1406. if "SESSION_ID" in os.environ:
  1407. session_id = int(os.environ['SESSION_ID'])
  1408. else:
  1409. raise RuntimeError("Testcase requires SESSION_ID environment variable")
  1410. some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
  1411. # This DATA_DIR only has 2 images in it
  1412. ds1 = ds.ImageFolderDataset(dataset_dir=DATA_DIR, cache=some_cache)
  1413. decode_op = c_vision.Decode()
  1414. ds1 = ds1.repeat(4)
  1415. ds1 = ds1.map(operations=decode_op, input_columns=["image"])
  1416. ds1 = ds1.repeat(2)
  1417. num_iter = 0
  1418. for _ in ds1.create_dict_iterator(num_epochs=1):
  1419. logger.info("get data from dataset")
  1420. num_iter += 1
  1421. logger.info("Number of data in ds1: {} ".format(num_iter))
  1422. assert num_iter == 16
  1423. logger.info('test_cache_map_nested_repeat Ended.\n')
  1424. if __name__ == '__main__':
  1425. test_cache_map_basic1()
  1426. test_cache_map_basic2()
  1427. test_cache_map_basic3()
  1428. test_cache_map_basic4()
  1429. test_cache_map_failure1()
  1430. test_cache_map_failure2()
  1431. test_cache_map_failure3()
  1432. test_cache_map_failure4()