# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with non-mappable datasets
"""
import os
import itertools
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

TEXT_TF_DATA_DIR = ["../data/dataset/testTextTFRecord/text.tfrecord"]
SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json"

TRAIN_DATA_DIR = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
                  "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
TRAIN_SCHEMA_DIR = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"

IMAGE_FOLDER_DATA_DIR = "../data/dataset/testImageNetData/train/"
CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"

GENERATE_GOLDEN = False
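
# These tests expect an external cache server to be running. A typical setup
# (a sketch; the exact cache_admin options may vary between MindSpore releases)
# looks like:
#   cache_admin --start          # bring up the cache server
#   cache_admin -g               # generate a cache session, prints the session id
#   export SESSION_ID=<id>       # session id consumed by the tests below
#   export RUN_CACHE_TEST=TRUE   # opt in to running these cache tests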


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic1():
    """
    A random dataset (a non mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # Create a cache; the session id comes from the environment
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # A randomly generated dataset with 10 rows, cached directly over the leaf
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)
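    # 10 rows in the leaf dataset, repeated 4 times -> 40 rows expected from the iterator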
    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic2():
    """
    A random dataset (a non mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # Create a cache; the session id comes from the environment
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # The sampler arg is not given directly; however, any of these args will auto-generate
    # an appropriate sampler: num_samples, shuffle, num_shards, shard_id.
    # In this case, the presence of num_samples chooses a sampler.
    ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(2)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic3():
    """
    A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    # Contact the server to get the statistics
    stat = some_cache.GetStat()
    cache_sz = stat.avg_cache_sz
    num_mem_cached = stat.num_mem_cached
    num_disk_cached = stat.num_disk_cached

    logger.info("Number of rows cached in memory: {}".format(num_mem_cached))
    logger.info("Number of rows spilled to disk: {}".format(num_disk_cached))
    logger.info("Average row cache size: {}".format(cache_sz))
    logger.info("test_cache_nomap_basic3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic4():
    """
    A TF reader dataset (a non mappable dataset) with a map decode and cache after it
    Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf.
    But, if there's a cache later, that shuffle becomes invalid and should be removed.

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache
    # in the picture. This causes a shuffle-injection over the TF. For clarity, this test will
    # explicitly give the global option, even though it's the default in python.
    # But, when caching is added in the tree above TF, we do global shuffling
    # through the sampler over the cache, not by the shuffle op. In that case, tree prepare
    # will remove the shuffle op that got injected by the initial tree creation.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic5():
    """
    A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
    Same as test 3, but this one does not give the shuffle arg, causing tf to default to global
    shuffle, which attempts to inject a shuffle operator. However, since there is a cache,
    we do not need global shuffle, so the shuffle will not be built. It ends up being
    identical to test basic 3, but we arrive at the same tree through different code paths
    (if there was no cache, then the shuffle IS built).

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic5 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic6():
    """
    A TF reader dataset (a non mappable dataset) with a cache over it just after the leaf
    In this one, the tf dataset will be given a sharding configuration, however since a cache is
    used, the tree prepare should undo the sharding configuration and instead, a distributed
    sampler will be chosen with the same shard config.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 6")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the tf record leaf node.
    # In this case, it is a row-based sharding, not the file-based sharding that would
    # happen if there was no cache.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_basic6 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic7():
    """
    A TF reader dataset (a non mappable dataset) that uses global shuffle and is cached,
    followed by a map.
    In this one, the tf dataset with global shuffle might want to inject a shuffle op over
    top of the tf reader, but since a cache is given, it will choose not to.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic7 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_basic8():
    """
    Test cache as root node

       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_basic8 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share1():
    """
    It is allowed to share the cache between the following two trees:

       Repeat     Shuffle
         |           |
       Cache       Cache
         |           |
      TFReader    TFReader
    """
    logger.info("Test cache nomap allowed share 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)
    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)
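
    # Iterating ds1 first populates the shared cache; ds2 should then be served from
    # the same cached rows (identical leaf + cache pipeline, so sharing is allowed).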
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("Number of data in ds1: {} ".format(num_iter))

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share2():
    """
    It is allowed to share the cache between the following two trees (with map decode):

       Repeat        Shuffle
         |              |
       Cache          Cache
         |              |
     Map(decode)    Map(decode)
         |              |
      TFReader       TFReader
    """
    logger.info("Test cache nomap allowed share 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)
    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share3():
    """
    It is allowed to share the cache between the following two trees (different shard ids):

           Repeat                    Repeat
             |                         |
           Cache                     Cache
             |                         |
    TFReader(shard_id = 0)    TFReader(shard_id = 1)
    """
    logger.info("Test cache nomap allowed share 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
    ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("test_cache_nomap_allowed_share3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_allowed_share4():
    """
    It is allowed to share the cache between the following two trees:

                  Cache                                     Cache
                    |                                         |
    Map(decode, num_parallel_workers=1)     Map(decode, num_parallel_workers=2)
                    |                                         |
                 TFReader                                  TFReader
    """
    logger.info("Test cache nomap allowed share 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=1)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds2: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_disallowed_share1():
    """
    It is not allowed to share the cache between the following two trees:

       Cache           Cache
         |               |
     Map(decode)    Map(rescale)
         |               |
      TFReader        TFReader
    """
    logger.info("Test cache nomap disallowed share1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=rescale_op, input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds2])
    assert "Attempt to re-use a cache for a different tree!" in str(e.value)
    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_running_twice1():
    """
    Executing the same pipeline twice (from python), with the cache injected after the map

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_running_twice2():
    """
    Executing the same pipeline twice (from shell), with the cache injected after the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_extra_small_size1():
    """
    Test running the pipeline with an extra small cache size and spilling enabled

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)
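    # size=1 requests an extra small cache (the size argument is in MB in MindSpore's
    # DatasetCache API); with spilling enabled, rows that do not fit in memory should
    # spill to disk, so the pipeline is still expected to succeed.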

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_extra_small_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_extra_small_size2():
    """
    Test running the pipeline with an extra small cache size and spilling disabled (expected failure)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap extra small size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)
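    # With spilling disabled, a 1 MB cache cannot hold the decoded images, so
    # iterating the pipeline is expected to fail with an out-of-memory error.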

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds1])
    assert "Out of memory" in str(e.value)
    logger.info("test_cache_nomap_extra_small_size2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_pipeline1(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after leaf op

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap parallel pipeline 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)
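
    # The shard argument is presumably supplied by an external runner script that
    # launches copies of this test in parallel (pytest alone cannot fill it in).
    # With 3 records sharded 3 ways, this shard sees 1 row; repeat(4) -> 4 rows expected.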
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_pipeline2(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after map op

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_parallel_workers():
    """
    Test cache with num_parallel_workers > 1 set for map op and leaf op

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_parallel_workers Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_workers_1():
    """
    Start the cache server with --workers 1 and then test the cache function

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap server workers 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_workers_100():
    """
    Start the cache server with --workers 100 and then test the cache function

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap server workers 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_num_connections_1():
    """
    Test setting num_connections=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)
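    # num_connections bounds how many TCP/IP connections the cache client opens to the
    # cache server (MindSpore's documented default is 12); here we force a single connection.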

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_num_connections_100():
    """
    Test setting num_connections=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_prefetch_size_1():
    """
    Test setting prefetch_size=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)
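    # prefetch_size controls how many cached rows the client fetches ahead of the
    # pipeline (MindSpore's documented default is 20); prefetch_size=1 effectively
    # disables read-ahead.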

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_prefetch_size_100():
    """
    Test setting prefetch_size=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_to_device():
    """
    Test cache with to_device

     DeviceQueue
         |
     EpochCtrl
         |
       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap to_device")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
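    # to_device()/send() pushes rows to the device queue asynchronously; nothing is
    # iterated on the host side, so this test only verifies that the pipeline launches.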
    ds1 = ds1.to_device()
    ds1.send()
    logger.info("test_cache_nomap_to_device Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_session_destroy():
    """
    Test executing cache_admin -d while the pipeline is running

       Repeat
         |
       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap session destroy")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # A randomly generated dataset (no fixed row count), cached over the leaf
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()
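
    # The pipeline repeats forever; the surrounding test script is expected to run
    # "cache_admin -d" (destroy the session) while this loop is iterating, which
    # surfaces on the client side as a RuntimeError.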
    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Unexpected error" in str(e.value)
    logger.info("test_cache_nomap_session_destroy Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_server_stop():
    """
    Test executing cache_admin --stop while the pipeline is running

       Repeat
         |
       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap server stop")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # A randomly generated dataset (no fixed row count), cached over the leaf
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Network error. Cache server is unreachable. Make sure the server is running." in str(e.value)
    logger.info("test_cache_nomap_server_stop Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl1():
    """
    Test using the two-loop method to run several epochs

     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_epoch_ctrl1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl2():
    """
    Test using the two-loop method with infinite epochs

       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_nomap_epoch_ctrl2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_epoch_ctrl3():
    """
    Test using the two-loop method with infinite epochs over repeat

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    # rely on the garbage collector to destroy iter1
    logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache1():
    """
    Test multiple caches in the same python script

       Cache               Cache
         |                   |
     Map(decode)         Map(decode)
         |                   |
    TFRecord(train)     TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    # This dataset has 3 records in it only
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache2():
    """
    Test multiple caches in the same python script

       Cache
         |
     Map(decode)         Cache
         |                 |
    TFRecord(image)    TFRecord(text)
    """
    logger.info("Test cache nomap multiple cache 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    # This dataset has 3 records in it only
    text_dataset = ds.TFRecordDataset(TEXT_TF_DATA_DIR, SCHEMA_DIR2, cache=text_cache)

    num_epoch = 5
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)
    text_iter = text_dataset.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _, _ in itertools.zip_longest(image_iter, text_iter):
            row_count += 1
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache3():
    """
    Test multiple caches in the same python script

       Cache             Cache
         |                 |
     Map(decode)       Map(decode)
         |                 |
      TFRecord        ImageFolder
    """
    logger.info("Test cache nomap multiple cache 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache)

    # This image folder only has 2 images in it
    image_dataset = ds.ImageFolderDataset(dataset_dir=IMAGE_FOLDER_DATA_DIR)
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    num_epoch = 5
    tf_iter = tf_dataset.create_dict_iterator(num_epochs=num_epoch)
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in tf_iter]) == 3
        assert sum([1 for _ in image_iter]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_train():
    """
    Test multiple caches in different python scripts. This test case is going to run
    concurrently with test_cache_nomap_multiple_cache_eval.

       Cache
         |
     Map(decode)
         |
    TFRecord(train)
    """
    logger.info("Test cache nomap multiple cache train")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(TRAIN_DATA_DIR, TRAIN_SCHEMA_DIR)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache_train Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_multiple_cache_eval():
    """
    Test multiple cache in different python scripts. This test case is going to run concurrently with
    test_cache_nomap_multiple_cache_train.

       cache
         |
     Map(decode)
         |
     TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache eval")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset only has 3 records in it
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_multiple_cache_eval Ended.\n")
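

def _run_train_and_eval_concurrently():
    """
    A minimal sketch (hypothetical helper, not part of the original suite) of
    launching the train and eval cases above concurrently from one driver,
    mirroring the "different python scripts" scenario in their docstrings.
    Assumes pytest is on PATH and SESSION_ID/RUN_CACHE_TEST are exported.
    """
    import subprocess
    procs = [subprocess.Popen(["pytest", "-k", name, __file__])
             for name in ("test_cache_nomap_multiple_cache_train",
                          "test_cache_nomap_multiple_cache_eval")]
    # Both runs talk to the same cache server session without interfering.
    assert all(p.wait() == 0 for p in procs)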


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue1():
    """
    A clue dataset (a non-mappable dataset) with a cache over it, just after the leaf.
    The clue dataset is given a sharding configuration here; however, since a cache is
    used, the tree prepare phase should undo that configuration and instead choose a
    distributed sampler with the same shard config.

       Cache
         |
       CLUE
    """
    logger.info("Test cache nomap clue 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the clue leaf node.
    # In this case it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue1 Ended.\n")
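

def _rows_for_shard(num_rows, num_shards, shard_id):
    """
    A minimal sketch (hypothetical helper, not used by the tests) of the
    row-based sharding a distributed sampler performs once a cache forces the
    leaf to produce every row: shard k of n takes every n-th row starting at k.
    e.g. _rows_for_shard(3, 3, 1) == [1], i.e. one row for this shard, matching
    the per-shard count asserted in test_cache_nomap_clue1 above.
    """
    return list(range(num_rows))[shard_id::num_shards]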


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_clue2():
    """
    A clue dataset (a non-mappable dataset) with a cache over it, after a map.
    In this one, a num_samples argument is given.

       Cache
         |
     Map(lambda x: x)
         |
       CLUE
    """
    logger.info("Test cache nomap clue 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
    ds1 = ds1.map(operations=(lambda x: x), input_columns=["label"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_clue2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv1():
    """
    A csv dataset (a non-mappable dataset) with a cache over it, just after the leaf.
    The csv dataset is given a sharding configuration here; however, since a cache is
    used, the tree prepare phase should undo that configuration and instead choose a
    distributed sampler with the same shard config.

       Cache
         |
        CSV
    """
    logger.info("Test cache nomap csv 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the csv leaf node.
    # In this case it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_csv2():
    """
    A csv dataset (a non-mappable dataset) with a cache over it, after a map.
    In this one, a num_samples argument is given.

       Cache
         |
     Map(lambda x: x)
         |
        CSV
    """
    logger.info("Test cache nomap csv 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
                        column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
    ds1 = ds1.map(operations=(lambda x: x), input_columns=["col1"], cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_csv2 Ended.\n")
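
# Note: placing the cache above the map (as in clue2/csv2) stores the mapped
# rows, so the lambda runs only while the cache is being filled; later epochs
# are served directly from the cache server.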


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile1():
    """
    A text file dataset (a non-mappable dataset) with a cache over it, just after the leaf.
    The text file dataset is given a sharding configuration here; however, since a cache is
    used, the tree prepare phase should undo that configuration and instead choose a
    distributed sampler with the same shard config.

       Cache
         |
      TextFile
    """
    logger.info("Test cache nomap textfile 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the text file leaf node.
    # In this case it is row-based sharding, not the file-based sharding that would
    # happen if there were no cache.
    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_shards=3, shard_id=1, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 1
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_textfile2():
    """
    A text file dataset (a non-mappable dataset) with a cache over it, after a map.
    In this one, a num_samples argument is given.

       Cache
         |
     Map(tokenizer)
         |
      TextFile
    """
    def my_tokenizer(line):
        words = line.split()
        if not words:
            return [""]
        return words

    logger.info("Test cache nomap textfile 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    ds1 = ds.TextFileDataset(TEXT_FILE_DATA_DIR, num_samples=2)
    tokenizer = text.PythonTokenizer(my_tokenizer)
    ds1 = ds1.map(operations=tokenizer, cache=some_cache)

    num_epoch = 4
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in iter1]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch

    logger.info("test_cache_nomap_textfile2 Ended.\n")
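
# Illustrative only: the whitespace tokenizer above maps, for example,
# "hello world" -> ["hello", "world"], and an empty line -> [""], so every
# cached row still yields a non-empty string array.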


if __name__ == '__main__':
    test_cache_nomap_basic1()
    test_cache_nomap_basic2()
    test_cache_nomap_basic3()
    test_cache_nomap_basic4()
    test_cache_nomap_basic5()
    test_cache_nomap_basic6()
    test_cache_nomap_basic7()
    test_cache_nomap_allowed_share1()
    test_cache_nomap_allowed_share2()
    test_cache_nomap_allowed_share3()
    test_cache_nomap_allowed_share4()
    test_cache_nomap_disallowed_share1()