
test_cache_nomap.py 82 kB

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with non-mappable datasets
"""
import os
import itertools
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"]
SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json"

TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"

IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/"
CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"
PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"

GENERATE_GOLDEN = False
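
# Every test below repeats the same SESSION_ID lookup inline. Purely as an
# illustrative sketch (not used by the tests, which keep the original inline
# form), that boilerplate could be factored out like this:
def _get_session_id():
    """Return the cache session id exported by the external cache_admin setup."""
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    return int(os.environ['SESSION_ID'])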

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic1():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache. arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic2():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache. arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # sampler arg not given directly; however, any of these args will auto-generate an appropriate sampler:
    # num_samples, shuffle, num_shards, shard_id
    # In this case, the presence of num_samples chooses a sampler.
    ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(2)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic3():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    # Contact the server to get the statistics
    stat = some_cache.GetStat()
    cache_sz = stat.avg_cache_sz
    num_mem_cached = stat.num_mem_cached
    num_disk_cached = stat.num_disk_cached

    logger.info("Number of rows cached in memory: {}".format(num_mem_cached))
    logger.info("Number of rows spilled to disk: {}".format(num_disk_cached))
    logger.info("Average row cache size: {}".format(cache_sz))
    logger.info("test_cache_nomap_basic3 Ended.\n")
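
# Illustrative note: with this 3-record dataset and an unlimited cache size,
# one would expect every row to land in memory, so a stricter test could
# assert, e.g., stat.num_mem_cached == 3 and stat.num_disk_cached == 0.
# The original test only logs the statistics rather than asserting on them.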

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic4():
    """
    A TF reader dataset (a non-mappable dataset) with a map decode and cache after it.
    Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf.
    But, if there's a cache later, that shuffle becomes invalid and should be removed.

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache
    # in the picture. This causes a shuffle-injection over the TF. For clarity, this test will
    # explicitly give the global option, even though it's the default in Python.
    # But, when caching is added in the ascendant tree above TF, we do global shuffling
    # through the sampler over the cache, not by the shuffle op. In that case, tree prepare
    # will remove the shuffle op that got injected by the initial tree creation.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic4 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic5():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf.
    Same as test 3, but this one does not give the shuffle arg, causing TF to default to global
    shuffle, which attempts to inject a shuffle operator. However, since there is a cache,
    we do not need global shuffle, so the shuffle will not be built. It ends up being
    identical to test basic 3; however, we arrive at the same tree via different code paths
    (if there were no cache, then the shuffle IS built).

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic5 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic6():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the tf dataset is given a sharding configuration; however, since a cache is
    used, tree prepare should undo the sharding configuration and instead choose a distributed
    sampler with the same shard config.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 6")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the tf record leaf node.
    # In this case, it is a row-based sharding, not the file-based sharding that would happen if
    # there were no cache.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_basic6 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic7():
    """
    A TF reader dataset (a non-mappable dataset) that uses global shuffle and is cached, followed
    by a map.
    In this one, the tf dataset with global shuffle might want to inject a shuffle op on top of
    the tf reader, but since a cache is given, it will choose not to.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic7 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic8():
    """
    Test cache as the root node

       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3
    logger.info('test_cache_nomap_basic8 Ended.\n')

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic9():
    """
    Test the GetStat interface for getting some info from the server. This should fail if the
    cache is not created in a pipeline.
    """
    logger.info("Test cache nomap basic 9")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # Contact the server to get the statistics; this should fail because we have not used this
    # cache in any pipeline, so there will not be any cache to get stats on.
    with pytest.raises(RuntimeError) as e:
        stat = some_cache.GetStat()
        cache_sz = stat.avg_cache_sz
        logger.info("Average row cache size: {}".format(cache_sz))
    assert "Unexpected error" in str(e.value)
    logger.info("test_cache_nomap_basic9 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share1():
    """
    It is allowed to share the cache between the following two trees:

       Repeat     Shuffle
         |           |
       Cache       Cache
         |           |
      TFReader    TFReader
    """
    logger.info("Test cache nomap allowed share 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)
    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=32)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("Number of data in ds1: {} ".format(num_iter))

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share2():
    """
    It is allowed to share the cache between the following two trees (with map decode):

       Repeat         Shuffle
         |               |
       Cache           Cache
         |               |
     Map(decode)     Map(decode)
         |               |
      TFReader        TFReader
    """
    logger.info("Test cache nomap allowed share 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)
    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share3():
    """
    It is allowed to share the cache between the following two trees (different shard ids):

       Repeat                     Repeat
         |                          |
       Cache                      Cache
         |                          |
    TFReader(shard_id = 0)     TFReader(shard_id = 1)
    """
    logger.info("Test cache nomap allowed share 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
    ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("test_cache_nomap_allowed_share3 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share4():
    """
    It is allowed to share the cache between the following two trees:

       Cache                                   Cache
         |                                       |
    Map(decode, num_parallel_workers=1)    Map(decode, num_parallel_workers=2)
         |                                       |
      TFReader                                TFReader
    """
    logger.info("Test cache nomap allowed share 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=1)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds2: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share4 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_disallowed_share1():
    """
    It is not allowed to share the cache between the following two trees:

       Cache            Cache
         |                |
     Map(decode)     Map(rescale)
         |                |
      TFReader         TFReader
    """
    logger.info("Test cache nomap disallowed share1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=rescale_op, input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds2])
    assert "Cannot re-use a cache for a different tree!" in str(e.value)
    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")
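
# As the error message above indicates, a cache instance can only be shared by
# pipelines whose sub-trees beneath the cache are identical. Here the map ops
# differ (decode vs rescale), so the two trees would produce different rows,
# and the server refuses to serve the second tree from the first tree's cache.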

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_running_twice1():
    """
    Executing the same pipeline twice (from Python), with the cache injected after the map

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_running_twice2():
    """
    Executing the same pipeline twice (from shell), with the cache injected after the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_extra_small_size1():
    """
    Test running a pipeline with a cache of extra-small size and spilling enabled

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_extra_small_size1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_extra_small_size2():
    """
    Test running a pipeline with a cache of extra-small size and spilling disabled (failure)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap extra small size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds1])
    assert "Out of memory" in str(e.value)
    logger.info("test_cache_nomap_extra_small_size2 Ended.\n")
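
# Note on the two extra-small-size tests: DatasetCache's size argument is the
# memory cap for the cache (0 means no limit; non-zero values are interpreted
# in MB per the MindSpore dataset API). A 1 MB cache cannot hold the decoded
# images, so the previous test survives by spilling to disk, while this one,
# with spilling disabled, fails with the "Out of memory" error asserted above.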

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_pipeline1(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after the leaf op

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap parallel pipeline 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_pipeline2(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after the map op

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_workers():
    """
    Test cache with num_parallel_workers > 1 set for the map op and the leaf op

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_parallel_workers Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_workers_1():
    """
    Start the cache server with --workers 1 and then test the cache function

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap server workers 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_workers_100():
    """
    Start the cache server with --workers 100 and then test the cache function

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap server workers 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_100 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_num_connections_1():
    """
    Test setting num_connections=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=1)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_num_connections_100():
    """
    Test setting num_connections=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, num_connections=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_100 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_prefetch_size_1():
    """
    Test setting prefetch_size=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=1)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_prefetch_size_100():
    """
    Test setting prefetch_size=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, prefetch_size=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_100 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_to_device():
    """
    Test cache with to_device

     DeviceQueue
         |
      EpochCtrl
         |
       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap to_device")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds1 = ds1.to_device()
    ds1.send()
    logger.info("test_cache_nomap_to_device Ended.\n")
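
# This test only checks that a cached pipeline can be wired into the device
# queue: to_device() plus send() hand the rows off to the device asynchronously,
# and nothing is iterated or asserted on the host side.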

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_session_destroy():
    """
    Test executing cache_admin -d while the pipeline is running

       Repeat
         |
       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap session destroy")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Unexpected error" in str(e.value)
    logger.info("test_cache_nomap_session_destroy Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_stop():
    """
    Test executing cache_admin --stop while the pipeline is running

       Repeat
         |
       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap server stop")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Network error. Cache server with port 50052 is unreachable. Make sure the server is running." in \
        str(e.value)
    logger.info("test_cache_nomap_server_stop Ended.\n")
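
# The two tests above (session destroy and server stop) rely on an external
# driver, presumably the shell script that launches this file, to issue the
# corresponding cache_admin command (-d or --stop) while the infinite pipeline
# is running; executed standalone, they would iterate until interrupted.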

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_interrupt_and_rerun():
    """
    Test interrupting a running pipeline and then re-using the same cache to run another pipeline

       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap interrupt and rerun")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, total_rows=10000, num_parallel_workers=4, cache=some_cache)
    iter1 = ds1.create_dict_iterator()

    num_iter = 0
    with pytest.raises(AttributeError) as e:
        for _ in iter1:
            num_iter += 1
            if num_iter == 10:
                iter1.stop()
    assert "'DictIterator' object has no attribute '_runtime_context'" in str(e.value)

    num_epoch = 2
    iter2 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        num_iter = 0
        for _ in iter2:
            num_iter += 1
        logger.info("Number of data in ds1: {} ".format(num_iter))
        assert num_iter == 10000
        epoch_count += 1

    cache_stat = some_cache.GetStat()
    assert cache_stat.num_mem_cached == 10000
    logger.info("test_cache_nomap_interrupt_and_rerun Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl1():
    """
    Test using the two-loop method to run several epochs

     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_epoch_ctrl1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl2():
    """
    Test using the two-loop method with infinite epochs

       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_nomap_epoch_ctrl2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl3():
    """
    Test using the two-loop method with infinite epochs over repeat

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    # rely on the garbage collector to destroy iter1
    logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl4():
    """
    Test using two-loops method with repeat under cache

       cache
         |
     Map(decode)
         |
       repeat
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    ds1 = ds1.repeat(2)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
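    # The cache sits above the repeat, so all 6 rows (3 records x 2 repeats) are
    # cached and served back on every epoch.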
    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache1():
    """
    Test multiple cache in the same python script

       cache                  cache
         |                      |
     Map(decode)            Map(decode)
         |                      |
    TFRecord(train)        TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    train_cache = ds.DatasetCache(session_id=session_id, size=0)
    eval_cache = ds.DatasetCache(session_id=session_id, size=0)
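    # Two separate cache instances share the same session, one per pipeline, so
    # the train rows and eval rows are cached independently of each other.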
    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    # This dataset has 3 records in it only
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache2():
    """
    Test multiple cache in the same python script

       cache
         |
     Map(decode)            cache
         |                    |
    TFRecord(image)      TFRecord(text)
    """
    logger.info("Test cache nomap multiple cache 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    image_cache = ds.DatasetCache(session_id=session_id, size=0)
    text_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    # This dataset has 3 records in it only
    text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache)

    num_epoch = 5
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)
    text_iter = text_dataset.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
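    # Drain both pipelines in lockstep; since each has 3 rows, zip_longest yields
    # exactly 3 pairs (a length mismatch would show up as extra None-padded pairs).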
    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _, _ in itertools.zip_longest(image_iter, text_iter):
            row_count += 1
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache3():
    """
    Test multiple cache in the same python script

       cache                  cache
         |                      |
     Map(decode)            Map(decode)
         |                      |
      TFRecord             ImageFolder
    """
    logger.info("Test cache nomap multiple cache 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    tf_cache = ds.DatasetCache(session_id=session_id, size=0)
    image_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache)

    # This DATA_DIR only has 2 images in it
    image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR)
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    num_epoch = 5
    tf_iter = tf_dataset.create_dict_iterator(num_epochs=num_epoch)
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in tf_iter]) == 3
        assert sum([1 for _ in image_iter]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_train():
    """
    Test multiple cache in different python scripts. This test case is going to run concurrently with
    test_cache_nomap_multiple_cache_eval.

       cache
         |
     Map(decode)
         |
    TFRecord(train)
    """
    logger.info("Test cache nomap multiple cache train")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    train_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache_train Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_eval():
    """
    Test multiple cache in different python scripts. This test case is going to run concurrently with
    test_cache_nomap_multiple_cache_train.

       cache
         |
     Map(decode)
         |
    TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache eval")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    eval_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset only has 3 records in it
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache_eval Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue1():
    """
    A clue dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the clue dataset is given a sharding configuration; however, since a cache is
    used, the tree prepare should undo the sharding configuration and instead choose a distributed
    sampler with the same shard config.

       Cache
         |
       CLUE
    """
    logger.info("Test cache nomap clue 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard
    # However, the sharding will be done by the sampler, not by the clue leaf node
    # In this case, it is a row-based sharding, not the file-based sharding that would happen if
    # there was not any cache.
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue2():
    """
    A clue dataset (a non mappable dataset) with a cache over it after map
    In this one, a num_samples argument is given

       Cache
         |
    map(lambda x: x)
         |
       CLUE
    """
    logger.info("Test cache nomap clue 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
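    # The identity lambda is wrapped in not_random so the map op is accepted as
    # deterministic and therefore cacheable.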
    ds1 = ds1.map(py_vision.not_random(lambda x: x), ["label"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv1():
    """
    A csv dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the csv dataset is given a sharding configuration; however, since a cache is
    used, the tree prepare should undo the sharding configuration and instead choose a distributed
    sampler with the same shard config.

       Cache
         |
        CSV
    """
    logger.info("Test cache nomap csv 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard
    # However, the sharding will be done by the sampler, not by the csv leaf node
    # In this case, it is a row-based sharding, not the file-based sharding that would happen if
    # there was not any cache.
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv2():
    """
    A csv dataset (a non mappable dataset) with a cache over it after map
    In this one, a num_samples argument is given

       Cache
         |
    map(lambda x: x)
         |
        CSV
    """
    logger.info("Test cache nomap csv 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
    ds1 = ds1.map(py_vision.not_random(lambda x: x), ["col1"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile1():
    """
    A text file dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the text file dataset is given a sharding configuration; however, since a cache is
    used, the tree prepare should undo the sharding configuration and instead choose a distributed
    sampler with the same shard config.

       Cache
         |
      TextFile
    """
    logger.info("Test cache nomap textfile 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard
    # However, the sharding will be done by the sampler, not by the text file leaf node
    # In this case, it is a row-based sharding, not the file-based sharding that would happen if
    # there was not any cache.
    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile2():
    """
    A text file dataset (a non mappable dataset) with a cache over it after map
    In this one, a num_samples argument is given

       Cache
         |
    Map(tokenizer)
         |
      TextFile
    """
    def my_tokenizer(line):
        words = line.split()
        if not words:
            return [""]
        return words

    logger.info("Test cache nomap textfile 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
    tokenizer = text.PythonTokenizer(my_tokenizer)
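    # Tokenization here is deterministic (a plain whitespace split), so caching
    # the mapped result is allowed.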
    ds1 = ds1.map(operations=tokenizer, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_nested_repeat():
    """
    Test cache on pipeline with nested repeat ops

       Repeat
         |
       Cache
         |
     Map(decode)
         |
       Repeat
         |
      TFRecord
    """
    logger.info("Test cache nomap nested repeat")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(2)
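    # 3 records x repeat(4) = 12 cached rows; the repeat(2) above the cache then
    # replays them, so one pass yields 24 rows.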
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 24
    logger.info('test_cache_nomap_nested_repeat Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_get_repeat_count():
    """
    Test get_repeat_count() for a pipeline with cache and nested repeat ops

       Cache
         |
     Map(decode)
         |
       Repeat
         |
      TFRecord
    """
    logger.info("Test cache nomap get_repeat_count")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.repeat(4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
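    # get_repeat_count() reports the repeat factor inside the pipeline;
    # 3 records x 4 repeats = 12 rows per pass.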
    repeat_count = ds1.get_repeat_count()
    logger.info("repeat_count: {}".format(repeat_count))
    assert repeat_count == 4

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1
    assert num_iter == 12


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_long_file_list():
    """
    Test cache after TFRecord with a long list of files as arguments

       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap long file list")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1)
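    # A deliberately tiny cache (size=1): caching the same file 1000 times over is
    # expected to exceed the limit and surface an out-of-memory error.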
    ds1 = ds.TFRecordDataset([DATA_DIR[0] for _ in range(0, 1000)], SCHEMA_DIR, columns_list=["image"],
                             cache=some_cache)

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds1])
    assert "Out of memory" in str(e.value)
    logger.info("test_cache_nomap_long_file_list Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure1():
    """
    Test nested cache (failure)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap failure 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)
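    # The nested-cache error is detected when the execution tree is prepared, so it
    # surfaces from both the getter call and the iteration, before any row is produced.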
    with pytest.raises(RuntimeError) as e:
        ds1.get_batch_size()
    assert "Nested cache operations" in str(e.value)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator(num_epochs=1):
            num_iter += 1
    assert "Nested cache operations" in str(e.value)
    assert num_iter == 0
    logger.info('test_cache_nomap_failure1 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure2():
    """
    Test zip under cache (failure)

        repeat
          |
        Cache
          |
      Map(decode)
          |
         Zip
        |    |
    Random  Random
    """
    logger.info("Test cache nomap failure 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])
    ds1 = ds.RandomDataset(schema=schema)
    ds2 = ds.RandomDataset(schema=schema)
    dsz = ds.zip((ds1, ds2))
    decode_op = c_vision.Decode()
    dsz = dsz.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    dsz = dsz.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in dsz.create_dict_iterator():
            num_iter += 1
    assert "ZipNode is not supported as a descendant operator under a cache" in str(e.value)
    assert num_iter == 0
    logger.info('test_cache_nomap_failure2 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure3():
    """
    Test batch under cache (failure)

       repeat
         |
       Cache
         |
     Map(resize)
         |
       Batch
         |
        Clue
    """
    logger.info("Test cache nomap failure 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train')
    ds1 = ds1.batch(2)
    resize_op = c_vision.Resize((224, 224))
    ds1 = ds1.map(input_columns=["image"], operations=resize_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "BatchNode is not supported as a descendant operator under a cache" in str(e.value)
    assert num_iter == 0
    logger.info('test_cache_nomap_failure3 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure4():
    """
    Test filter under cache (failure)

       repeat
         |
       Cache
         |
     Map(decode)
         |
       Filter
         |
        CSV
    """
    logger.info("Test cache nomap failure 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'])
    ds1 = ds1.filter(predicate=lambda data: data < 11, input_columns=["label"])
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "FilterNode is not supported as a descendant operator under a cache" in str(e.value)
    assert num_iter == 0
    logger.info('test_cache_nomap_failure4 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_failure5():
    """
    Test Map containing random operation under cache (failure)

       repeat
         |
       Cache
         |
    Map(decode, randomCrop)
         |
      TextFile
    """
    logger.info("Test cache nomap failure 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    data = ds.TextFileDataset(TEXT_FILE_DATA_DIR)
    random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
    decode_op = c_vision.Decode()
    data = data.map(input_columns=["image"], operations=decode_op)
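    # RandomCrop produces a different result on every pass; caching it would freeze
    # one random outcome, so the cache rejects this pipeline at prepare time.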
    data = data.map(input_columns=["image"], operations=random_crop_op, cache=some_cache)
    data = data.repeat(4)

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in data.create_dict_iterator():
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
    assert num_iter == 0
    logger.info('test_cache_nomap_failure5 Ended.\n')


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_lambda():
    """
    Test cache after map op with a python lambda function.
    Only allowed if the lambda function is wrapped by 'py_vision.not_random', otherwise an error will be raised.

       Cache
         |
    Map(lambda function1, lambda function2)
         |
      TFRecord
    """
    logger.info("Test cache nomap pyfunc lambda")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 12 records in it
    data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    transforms = [py_vision.not_random(lambda x: x + x), py_vision.not_random(lambda x: x - 1)]
    data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12

    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
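    # A bare lambda cannot be proven deterministic, so the cache treats it as a
    # random operation and rejects the pipeline below.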
    ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache)
    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator():
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
    logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_builtin():
    """
    Test cache after map op with a python builtin PyFunc.
    An error will be raised if the builtin pyfunc contains a random operation.

       Cache
         |
    Map([builtin pyfunc1, builtin pyfunc2])
         |
      TFRecord
    """
    logger.info("Test cache nomap pyfunc builtin")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[py_vision.Decode(), py_vision.ToTensor()], input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[py_vision.Decode(), py_vision.RandomCrop(224), py_vision.ToTensor()],
                  input_columns=["image"], cache=other_cache)
    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator():
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
    logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_function():
    """
    Test cache after map op with a python customized function.
    Only allowed if the function is decorated with 'py_vision.not_random', otherwise an error will be raised.

       Cache
         |
    Map([function1, function2])
         |
      TFRecord
    """
    @py_vision.not_random
    def not_random_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    def normal_func(x):
        return np.ones(x.shape, dtype=x.dtype)
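    # The two helpers do the same thing; only not_random_func is marked
    # deterministic via the decorator, so only it may run under a cache.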
    logger.info("Test cache nomap pyfunc function")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache)
    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds2.create_dict_iterator():
            num_iter += 1
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)
    logger.info("test_cache_nomap_pyfunc_function Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_all_rows_cached():
    """
    Make sure all rows are cached before we switch to the fetching phase

       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap all rows cached")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8, shape=[450, 450, 3])
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # easier to reproduce the problem with 271 total rows
    num_total_rows = 271
    # read the dataset with multiple parallel workers
    ds1 = ds.RandomDataset(schema=schema, total_rows=num_total_rows, num_parallel_workers=4, cache=some_cache)
    iter1 = ds1.create_dict_iterator()

    num_iter = 0
    for _ in iter1:
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == num_total_rows
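
    # Ask the cache server for its statistics to confirm that every row ended up
    # in the in-memory cache rather than being only partially cached.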
    cache_stat = some_cache.GetStat()
    assert cache_stat.num_mem_cached == num_total_rows

    logger.info("test_cache_nomap_all_rows_cached Ended.\n")


if __name__ == '__main__':
    # This is just a list of the tests; do not run them with 'python test_cache_nomap.py',
    # since the cache server must be brought up first
    test_cache_nomap_basic1()
    test_cache_nomap_basic2()
    test_cache_nomap_basic3()
    test_cache_nomap_basic4()
    test_cache_nomap_basic5()
    test_cache_nomap_basic6()
    test_cache_nomap_basic7()
    test_cache_nomap_basic8()
    test_cache_nomap_basic9()
    test_cache_nomap_allowed_share1()
    test_cache_nomap_allowed_share2()
    test_cache_nomap_allowed_share3()
    test_cache_nomap_allowed_share4()
    test_cache_nomap_disallowed_share1()
    test_cache_nomap_running_twice1()
    test_cache_nomap_running_twice2()
    test_cache_nomap_extra_small_size1()
    test_cache_nomap_extra_small_size2()
    test_cache_nomap_parallel_pipeline1(shard=0)
    test_cache_nomap_parallel_pipeline2(shard=1)
    test_cache_nomap_parallel_workers()
    test_cache_nomap_server_workers_1()
    test_cache_nomap_server_workers_100()
    test_cache_nomap_num_connections_1()
    test_cache_nomap_num_connections_100()
    test_cache_nomap_prefetch_size_1()
    test_cache_nomap_prefetch_size_100()
    test_cache_nomap_to_device()
    test_cache_nomap_session_destroy()
    test_cache_nomap_server_stop()
    test_cache_nomap_epoch_ctrl1()
    test_cache_nomap_epoch_ctrl2()
    test_cache_nomap_epoch_ctrl3()
    test_cache_nomap_epoch_ctrl4()
    test_cache_nomap_multiple_cache1()
    test_cache_nomap_multiple_cache2()
    test_cache_nomap_multiple_cache3()
    test_cache_nomap_multiple_cache_train()
    test_cache_nomap_multiple_cache_eval()
    test_cache_nomap_clue1()
    test_cache_nomap_clue2()
    test_cache_nomap_csv1()
    test_cache_nomap_csv2()
    test_cache_nomap_textfile1()
    test_cache_nomap_textfile2()
    test_cache_nomap_nested_repeat()
    test_cache_nomap_get_repeat_count()
    test_cache_nomap_long_file_list()
    test_cache_nomap_failure1()
    test_cache_nomap_failure2()
    test_cache_nomap_failure3()
    test_cache_nomap_failure4()
    test_cache_nomap_failure5()
    test_cache_nomap_pyfunc_lambda()
    test_cache_nomap_pyfunc_builtin()
    test_cache_nomap_pyfunc_function()