

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with non-mappable datasets
"""
import os
import itertools

import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"]
SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json"

TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"

IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/"
CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"

GENERATE_GOLDEN = False
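
# The tests below assume an external cache server is already running and that a
# valid session id has been exported as SESSION_ID. With MindSpore's cache_admin
# tool this is typically done along the following lines (exact flags may vary by
# version):
#   cache_admin --start
#   cache_admin -g          # generates and prints a new session id
#   export SESSION_ID=<printed id>


def session_id_from_env():
    """Minimal helper sketch: the SESSION_ID lookup that every test below repeats
    inline. Shown for illustration only; the tests deliberately stay
    self-contained and do not call it."""
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    return int(os.environ["SESSION_ID"])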

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic1():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])
    # create a cache; arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic1 Ended.\n")
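
# Note on the DatasetCache arguments used throughout this file: size=0 appears to
# mean "no explicit memory cap", and spilling=True lets rows overflow to disk once
# memory fills. The extra_small_size tests further down exercise the behaviour
# when a tiny size is combined with spilling on and off.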

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic2():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])
    # create a cache; arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # The sampler arg is not given directly; any of num_samples, shuffle, num_shards
    # or shard_id will auto-generate an appropriate sampler. In this case, the
    # presence of num_samples chooses the sampler.
    ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(2)
    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic3():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFReader
    """
    logger.info("Test cache nomap basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    # Contact the server to get the statistics
    stat = some_cache.GetStat()
    cache_sz = stat.avg_cache_sz
    num_mem_cached = stat.num_mem_cached
    num_disk_cached = stat.num_disk_cached
    logger.info("Number of rows cached in memory: {}".format(num_mem_cached))
    logger.info("Number of rows spilled to disk: {}".format(num_disk_cached))
    logger.info("Average row cache size: {}".format(cache_sz))
    logger.info("test_cache_nomap_basic3 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic4():
    """
    A TF reader dataset (a non-mappable dataset) with a map decode and cache after it.
    Since a global shuffle is used for the tf reader, it would normally inject a shuffle
    op over the tf reader. But if there is a cache later, that shuffle becomes invalid
    and should be removed.

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFReader
    """
    logger.info("Test cache nomap basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    # This dataset has only 3 records in it
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # With shuffle not set, TF defaults to a "global" shuffle when there is no cache
    # in the picture. This causes a shuffle-injection over the TF. For clarity, this
    # test explicitly gives the global option, even though it's the default in Python.
    # But when caching is added in the tree above TF, we do global shuffling through
    # the sampler over the cache, not through the shuffle op. In that case, tree
    # prepare will remove the shuffle op that got injected by the initial tree creation.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic4 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic5():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf.
    Same as test 3, but this one omits the shuffle arg, so tf defaults to a global
    shuffle, which would normally inject a shuffle operator. However, since there is a
    cache we do not need a global shuffle, so the shuffle will not be built. The result
    is identical to test basic 3, but we arrive at the same tree through a different
    codepath (with no cache, the shuffle IS built).

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFReader
    """
    logger.info("Test cache nomap basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    # This dataset has only 3 records in it
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic5 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic6():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the tf dataset is given a sharding configuration; however, since a
    cache is used, tree prepare should undo the sharding configuration and instead
    choose a distributed sampler with the same shard config.

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFReader
    """
    logger.info("Test cache nomap basic 6")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    # This dataset has only 3 records in it
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # With only 3 records sharded into 3, we expect only 1 record returned for this
    # shard. However, the sharding will be done by the sampler, not by the tf record
    # leaf node. In this case it is a row-based sharding, not the file-based sharding
    # that would happen if there were no cache.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_basic6 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic7():
    """
    A TF reader dataset (a non-mappable dataset) that uses global shuffle and is
    cached, followed by a map.
    Here the tf dataset with global shuffle might want to inject a shuffle op over
    the tf reader, but since a cache is given, it will choose not to.

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFReader
    """
    logger.info("Test cache nomap basic 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic7 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic8():
    """
    Test cache as the root node

        Cache
          |
       TFReader
    """
    logger.info("Test cache nomap basic 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3
    logger.info('test_cache_nomap_basic8 Ended.\n')

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic9():
    """
    Test the GetStat interface for getting info from the server; this should fail
    because the cache has not been created in a pipeline.
    """
    logger.info("Test cache nomap basic 9")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # Contact the server to get the statistics. This should fail because this cache
    # has not been used in any pipeline, so there is no cache to get stats on.
    with pytest.raises(RuntimeError) as e:
        stat = some_cache.GetStat()
        cache_sz = stat.avg_cache_sz
        logger.info("Average row cache size: {}".format(cache_sz))
    assert "Unexpected error" in str(e.value)
    logger.info("test_cache_nomap_basic9 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share1():
    """
    It is allowed to share the cache between the following two trees:

        Repeat      Shuffle
          |            |
        Cache        Cache
          |            |
       TFReader     TFReader
    """
    logger.info("Test cache nomap allowed share 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    ds.config.set_seed(1)
    # This dataset has only 3 records in it
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("Number of data in ds1: {} ".format(num_iter))
    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share2():
    """
    It is allowed to share the cache between the following two trees (with map decode):

        Repeat         Shuffle
          |               |
        Cache           Cache
          |               |
      Map(decode)     Map(decode)
          |               |
       TFReader        TFReader
    """
    logger.info("Test cache nomap allowed share 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    ds.config.set_seed(1)
    # This dataset has only 3 records in it
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share3():
    """
    It is allowed to share the cache between the following two trees (different shard ids):

        Repeat                     Repeat
          |                          |
        Cache                      Cache
          |                          |
      TFReader(shard_id=0)      TFReader(shard_id=1)
    """
    logger.info("Test cache nomap allowed share 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
    ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
    ds2 = ds2.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("test_cache_nomap_allowed_share3 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share4():
    """
    It is allowed to share the cache between the following two trees:

                    Cache                                    Cache
                      |                                        |
      Map(decode, num_parallel_workers=1)    Map(decode, num_parallel_workers=2)
                      |                                        |
                   TFReader                                 TFReader
    """
    logger.info("Test cache nomap allowed share 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    # This dataset has only 3 records in it
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=1)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=2)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3
    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds2: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share4 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_disallowed_share1():
    """
    It is NOT allowed to share the cache between the following two trees:

        Cache           Cache
          |               |
      Map(decode)     Map(rescale)
          |               |
       TFReader        TFReader
    """
    logger.info("Test cache nomap disallowed share1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    # This dataset has only 3 records in it
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=rescale_op, input_columns=["image"], cache=some_cache)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3
    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds2])
    assert "Attempt to re-use a cache for a different tree!" in str(e.value)
    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")
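
# The pattern across the share tests above appears to be: a cache can be shared
# whenever the sub-tree beneath the cache op produces the same data (same leaf
# and same ops under the cache; differing shard ids or worker counts are fine),
# whereas a different op under the cache (decode vs rescale) makes the cached
# content incompatible and the server rejects the re-use.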

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_running_twice1():
    """
    Execute the same pipeline twice (from Python), with the cache injected after map

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFRecord
    """
    logger.info("Test cache nomap running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_running_twice2():
    """
    Execute the same pipeline twice (from shell), with the cache injected after the leaf

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFRecord
    """
    logger.info("Test cache nomap running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice2 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_extra_small_size1():
    """
    Test running the pipeline with an extra-small cache size and spilling enabled

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFRecord
    """
    logger.info("Test cache nomap extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_extra_small_size1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_extra_small_size2():
    """
    Test running the pipeline with an extra-small cache size and spilling disabled
    (expected to fail)

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFRecord
    """
    logger.info("Test cache nomap extra small size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds1])
    assert "Out of memory" in str(e.value)
    logger.info("test_cache_nomap_extra_small_size2 Ended.\n")
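
# Taken together, the two extra_small_size tests show the spilling contract:
# with a tiny cache size, spilling=True lets rows overflow to disk and the
# pipeline still completes, while spilling=False runs the cache out of memory
# and surfaces an "Out of memory" RuntimeError to the iterator.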

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_parallel_pipeline1(shard):
    """
    Test running two parallel pipelines (sharing the cache) with the cache injected
    after the leaf op

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFReader
    """
    logger.info("Test cache nomap parallel pipeline 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_parallel_pipeline2(shard):
    """
    Test running two parallel pipelines (sharing the cache) with the cache injected
    after the map op

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFReader
    """
    logger.info("Test cache nomap parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline2 Ended.\n")
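
# Note: the two parallel_pipeline tests take a shard argument, so they are not
# collected as ordinary pytest functions; presumably an external runner script
# launches several copies concurrently with shard ids 0..2 so that the processes
# exercise one shared cache across shards.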

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_parallel_workers():
    """
    Test cache with num_parallel_workers > 1 set for the map op and the leaf op

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFReader
    """
    logger.info("Test cache nomap parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_parallel_workers Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_server_workers_1():
    """
    Start the cache server with --workers 1, then test the cache

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFRecord
    """
    logger.info("Test cache nomap server workers 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_server_workers_100():
    """
    Start the cache server with --workers 100, then test the cache

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFRecord
    """
    logger.info("Test cache nomap server workers 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_100 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_num_connections_1():
    """
    Test setting num_connections=1 in DatasetCache

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFRecord
    """
    logger.info("Test cache nomap num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_num_connections_100():
    """
    Test setting num_connections=100 in DatasetCache

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFRecord
    """
    logger.info("Test cache nomap num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_100 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_prefetch_size_1():
    """
    Test setting prefetch_size=1 in DatasetCache

        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFRecord
    """
    logger.info("Test cache nomap prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_prefetch_size_100():
    """
    Test setting prefetch_size=100 in DatasetCache

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFRecord
    """
    logger.info("Test cache nomap prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)
    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_100 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_to_device():
    """
    Test cache with to_device

      DeviceQueue
          |
       EpochCtrl
          |
        Repeat
          |
        Cache
          |
      Map(decode)
          |
       TFReader
    """
    logger.info("Test cache nomap to_device")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds1 = ds1.to_device()
    ds1.send()
    logger.info("test_cache_nomap_to_device Ended.\n")
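
# Note: to_device()/send() only pushes rows toward the device queue; nothing is
# returned to the host side, so there is no row count to assert here and the
# test passes as long as send() completes without raising.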

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_session_destroy():
    """
    Test executing cache_admin -d while the pipeline is running

        Repeat
          |
        Cache
          |
     RandomDataset
    """
    logger.info("Test cache nomap session destroy")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()
    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Unexpected error" in str(e.value)
    logger.info("test_cache_nomap_session_destroy Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_server_stop():
    """
    Test executing cache_admin --stop while the pipeline is running

        Repeat
          |
        Cache
          |
     RandomDataset
    """
    logger.info("Test cache nomap server stop")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()
    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Network error. Cache server is unreachable. Make sure the server is running." in str(e.value)
    logger.info("test_cache_nomap_server_stop Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_epoch_ctrl1():
    """
    Test using the two-loop method to run several epochs

      Map(decode)
          |
        Cache
          |
       TFRecord
    """
    logger.info("Test cache nomap epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_epoch_ctrl1 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_epoch_ctrl2():
    """
    Test using the two-loop method with infinite epochs

        Cache
          |
      Map(decode)
          |
       TFRecord
    """
    logger.info("Test cache nomap epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()
    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_nomap_epoch_ctrl2 Ended.\n")
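
# Epoch-control pattern used here: create_dict_iterator(num_epochs=n) lets the
# iterator shut down cleanly after the outer loop, while the default iterator
# assumes an unbounded number of epochs and must be released explicitly via
# stop(), as above, or left to the garbage collector (see epoch_ctrl3 below).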

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_epoch_ctrl3():
    """
    Test using the two-loop method with infinite epochs over repeat

        Repeat
          |
      Map(decode)
          |
        Cache
          |
       TFRecord
    """
    logger.info("Test cache nomap epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()
    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch
    # rely on the garbage collector to destroy iter1
    logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_epoch_ctrl4():
    """
    Test using the two-loop method with repeat under the cache

        Cache
          |
      Map(decode)
          |
        Repeat
          |
       TFRecord
    """
    logger.info("Test cache nomap epoch ctrl4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has only 3 records in it
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    ds1 = ds1.repeat(2)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_epoch_ctrl4 Ended.\n")

@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_multiple_cache1():
    """
    Test multiple caches in the same Python script

        Cache            Cache
          |                |
      Map(decode)      Map(decode)
          |                |
    TFRecord(train)   TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)
    # This dataset has only 3 records in it
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)
    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)
    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache2():
    """
    Test multiple caches in the same python script

       cache
         |
     Map(decode)         cache
         |                 |
    TFRecord(image)    TFRecord(text)
    """
    logger.info("Test cache nomap multiple cache 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    # This dataset has 3 records in it only
    text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache)

    num_epoch = 5
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)
    text_iter = text_dataset.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _, _ in itertools.zip_longest(image_iter, text_iter):
            row_count += 1
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache2 Ended.\n")
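
# zip_longest above advances both pipelines in lockstep; since both datasets hold 3
# records, neither side is ever padded with None. A sketch of the same pattern with
# an explicit exhaustion check (hypothetical helper, not used by the tests):
def _count_lockstep(iter_a, iter_b):
    count = 0
    for a, b in itertools.zip_longest(iter_a, iter_b):
        assert a is not None and b is not None  # both sides must yield equally many rows
        count += 1
    return count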


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache3():
    """
    Test multiple caches in the same python script

       cache               cache
         |                   |
     Map(decode)         Map(decode)
         |                   |
      TFRecord           ImageFolder
    """
    logger.info("Test cache nomap multiple cache 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache)

    # This directory only has 2 images in it
    image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR)
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    num_epoch = 5
    tf_iter = tf_dataset.create_dict_iterator(num_epochs=num_epoch)
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in tf_iter]) == 3
        assert sum([1 for _ in image_iter]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_train():
    """
    Test multiple caches in different python scripts. This test case is going to run
    concurrently with test_cache_nomap_multiple_cache_eval.

       cache
         |
     Map(decode)
         |
    TFRecord(train)
    """
    logger.info("Test cache nomap multiple cache train")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache_train Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_eval():
    """
    Test multiple caches in different python scripts. This test case is going to run
    concurrently with test_cache_nomap_multiple_cache_train.

       cache
         |
     Map(decode)
         |
    TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache eval")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset only has 3 records in it
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache_eval Ended.\n")
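
# Both concurrent scripts attach to the same cache server session through the shared
# SESSION_ID environment variable. That lookup is repeated verbatim in every test; a
# small helper like the hypothetical one below could factor it out unchanged:
def _get_session_id():
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    return int(os.environ['SESSION_ID'])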


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue1():
    """
    A clue dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the clue dataset is given a sharding configuration; however, since a
    cache is used, the tree prepare should undo the sharding configuration and instead
    choose a distributed sampler with the same shard config.

       Cache
         |
       CLUE
    """
    logger.info("Test cache nomap clue 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the clue leaf node.
    # In this case, it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue1 Ended.\n")
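
# Row-based sharding, sketched in plain Python (an assumption: rows dealt round-robin
# across shards, as a distributed sampler would do; without the cache, the leaf would
# instead split whole files across shards).
def _rows_for_shard(num_rows, num_shards, shard_id):
    return sum(1 for row in range(num_rows) if row % num_shards == shard_id)
# _rows_for_shard(3, 3, 1) == 1, matching the per-epoch count asserted above.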


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue2():
    """
    A clue dataset (a non-mappable dataset) with a cache over it after a map.
    In this one, a num_samples argument is given.

       Cache
         |
    map(lambda x: x)
         |
       CLUE
    """
    logger.info("Test cache nomap clue 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
    ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue2 Ended.\n")
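
# When the cache sits above a map with a Python callable, the callable presumably runs
# only while the cache is being filled; later epochs replay cached rows. A sketch with
# a call counter (hypothetical class, not used by the tests) would make that visible:
class _CountingUdf:
    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return x
# Mapping a _CountingUdf() with cache=some_cache, calls should stop growing after epoch 1.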


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv1():
    """
    A csv dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the csv dataset is given a sharding configuration; however, since a
    cache is used, the tree prepare should undo the sharding configuration and instead
    choose a distributed sampler with the same shard config.

       Cache
         |
        CSV
    """
    logger.info("Test cache nomap csv 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the csv leaf node.
    # In this case, it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv2():
    """
    A csv dataset (a non-mappable dataset) with a cache over it after a map.
    In this one, a num_samples argument is given.

       Cache
         |
    map(lambda x: x)
         |
        CSV
    """
    logger.info("Test cache nomap csv 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
    ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile1():
    """
    A text file dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the text file dataset is given a sharding configuration; however, since a
    cache is used, the tree prepare should undo the sharding configuration and instead
    choose a distributed sampler with the same shard config.

       Cache
         |
      TextFile
    """
    logger.info("Test cache nomap textfile 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the text file leaf node.
    # In this case, it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile2():
    """
    A text file dataset (a non-mappable dataset) with a cache over it after a map.
    In this one, a num_samples argument is given.

       Cache
         |
    Map(tokenizer)
         |
      TextFile
    """
    def my_tokenizer(line):
        words = line.split()
        if not words:
            return [""]
        return words

    logger.info("Test cache nomap textfile 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
    tokenizer = text.PythonTokenizer(my_tokenizer)
    ds1 = ds1.map(operations=tokenizer, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile2 Ended.\n")
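
# my_tokenizer above guards against empty lines because the tokenizer is expected to
# return a non-empty list of strings (an assumption; the [""] fallback in the test
# suggests it). The same guard, standalone and testable without a pipeline:
def _whitespace_tokenize(line):
    words = line.split()
    return words if words else [""]
# _whitespace_tokenize("cat sat") == ["cat", "sat"]; _whitespace_tokenize("") == [""]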


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_nested_repeat():
    """
    Test cache on a pipeline with nested repeat ops

       Repeat
         |
       Cache
         |
     Map(decode)
         |
       Repeat
         |
      TFRecord
    """
    logger.info("Test cache nomap nested repeat")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.repeat(4)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 24
    logger.info('test_cache_nomap_nested_repeat Ended.\n')
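
# Row-count arithmetic for the nested-repeat pipeline above: the repeat below the
# cache multiplies the rows that get cached, while the repeat above replays them.
def _expected_rows(leaf_rows, inner_repeat, outer_repeat):
    cached_rows = leaf_rows * inner_repeat   # rows materialized under the cache
    return cached_rows * outer_repeat        # rows the iterator actually sees
# _expected_rows(3, 4, 2) == 24, matching the assert in the test.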


if __name__ == '__main__':
    test_cache_nomap_basic1()
    test_cache_nomap_basic2()
    test_cache_nomap_basic3()
    test_cache_nomap_basic4()
    test_cache_nomap_basic5()
    test_cache_nomap_basic6()
    test_cache_nomap_basic7()
    test_cache_nomap_allowed_share1()
    test_cache_nomap_allowed_share2()
    test_cache_nomap_allowed_share3()
    test_cache_nomap_allowed_share4()
    test_cache_nomap_disallowed_share1()