test_cache_nomap.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with non-mappable datasets
"""
import os
import itertools
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

DATA_DIR2 = ["../data/dataset/testTextTFRecord/text.tfrecord"]
SCHEMA_DIR2 = "../data/dataset/testTextTFRecord/datasetSchema.json"

DATA_DIR3 = ["../data/dataset/test_tf_file_3_images2/train-0000-of-0001.data",
             "../data/dataset/test_tf_file_3_images2/train-0000-of-0002.data",
             "../data/dataset/test_tf_file_3_images2/train-0000-of-0003.data",
             "../data/dataset/test_tf_file_3_images2/train-0000-of-0004.data"]
SCHEMA_DIR3 = "../data/dataset/test_tf_file_3_images2/datasetSchema.json"

DATA_DIR4 = "../data/dataset/testImageNetData/train/"

GENERATE_GOLDEN = False
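
# NOTE: every test below is skipped unless RUN_CACHE_TEST=TRUE, and each one
# requires a running dataset cache server plus a SESSION_ID environment
# variable. A minimal setup sketch (assuming the cache_admin CLI referenced in
# the docstrings below; exact output format may differ by version):
#   cache_admin --start        # bring up the cache server
#   cache_admin -g             # generate a cache session; note the printed id
#   export SESSION_ID=<id printed by cache_admin -g>
#   export RUN_CACHE_TEST=TRUE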


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic1():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache. arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic2():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # create a cache. arbitrary session_id for now
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # sampler arg not given directly; however, any of these args will auto-generate an
    # appropriate sampler: num_samples, shuffle, num_shards, shard_id
    # In this case, the presence of num_samples chooses a sampler.
    ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(2)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 40
    logger.info("test_cache_nomap_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic3():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    # Contact the server to get the statistics
    stat = some_cache.GetStat()
    cache_sz = stat.avg_cache_sz
    num_mem_cached = stat.num_mem_cached
    num_disk_cached = stat.num_disk_cached

    logger.info("Number of rows cached in memory: {}".format(num_mem_cached))
    logger.info("Number of rows spilled to disk: {}".format(num_disk_cached))
    logger.info("Average row cache size: {}".format(cache_sz))
    logger.info("test_cache_nomap_basic3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic4():
    """
    A TF reader dataset (a non-mappable dataset) with a map decode and cache after it
    Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf.
    But, if there's a cache later, that shuffle becomes invalid and should be removed.

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap basic 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache
    # in the picture. This causes a shuffle-injection over the TF. For clarity, this test will
    # explicitly give the global option, even though it's the default in Python.
    # But, when caching is added in the tree above the TF reader, we do global shuffling
    # through the sampler over the cache, not by the shuffle op. In that case, tree prepare
    # will remove the shuffle op that got injected by the initial tree creation.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic5():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf
    Same as test 3, but this one does not have a shuffle arg, causing tf to default to global
    shuffle, which attempts to inject a shuffle operator. However, since there is a cache,
    we do not need global shuffle, so the shuffle will not be built. It ends up being
    identical to test basic 3; however, we arrive at the same tree via different codepaths
    (if there was no cache, then the shuffle IS built)

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 5")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic5 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic6():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf
    In this one, the tf dataset will be given a sharding configuration; however, since a cache is
    used, the tree prepare should undo the sharding configuration and instead, a distributed
    sampler will be chosen with the same shard config.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 6")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the tf record leaf node.
    # In this case, it is a row-based sharding, not the file-based sharding that would happen
    # if there was not any cache.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_basic6 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic7():
    """
    A TF reader dataset (a non-mappable dataset) that uses global shuffle, and is cached
    followed by map.
    In this one, the tf dataset with global shuffle might want to inject a shuffle op over
    top of the tf reader, but since a cache is given, it will choose not to.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 7")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(operations=decode_op, input_columns=["image"])
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_basic7 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic8():
    """
    Test cache as root node

       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 8")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        logger.info("get data from dataset")
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_basic8 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share1():
    """
    It is allowed to share the cache between the following two trees:

       Repeat     Shuffle
         |           |
       Cache       Cache
         |           |
     TFReader    TFReader
    """
    logger.info("Test cache nomap allowed share 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=32)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("Number of data in ds1: {} ".format(num_iter))

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share2():
    """
    It is allowed to share the cache between the following two trees (with map decode):

        Repeat         Shuffle
          |               |
        Cache           Cache
          |               |
     Map(decode)     Map(decode)
          |               |
      TFReader        TFReader
    """
    logger.info("Test cache nomap allowed share 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    ds.config.set_seed(1)

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share3():
    """
    It is allowed to share the cache between the following two trees (different shard ids):

            Repeat                   Repeat
              |                        |
            Cache                    Cache
              |                        |
    TFReader(shard_id = 0)    TFReader(shard_id = 1)
    """
    logger.info("Test cache nomap allowed share 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
    ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("test_cache_nomap_allowed_share3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share4():
    """
    It is allowed to share the cache between the following two trees:

                  Cache                                  Cache
                    |                                      |
    Map(decode, num_parallel_workers=1)    Map(decode, num_parallel_workers=2)
                    |                                      |
                TFReader                               TFReader
    """
    logger.info("Test cache nomap allowed share 4")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=1)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=decode_op, input_columns=["image"], cache=some_cache, num_parallel_workers=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds2: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_disallowed_share1():
    """
    It is not allowed to share the cache between the following two trees:

        Cache           Cache
          |               |
     Map(decode)     Map(rescale)
          |               |
      TFReader        TFReader
    """
    logger.info("Test cache nomap disallowed share1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(operations=decode_op, input_columns=["image"], cache=some_cache)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(operations=rescale_op, input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds2])
    assert "Attempt to re-use a cache for a different tree!" in str(e.value)
    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_running_twice1():
    """
    Executing the same pipeline twice (from Python), with cache injected after map

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap running twice 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_running_twice2():
    """
    Executing the same pipeline twice (from shell), with cache injected after leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap running twice 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_running_twice2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_extra_small_size1():
    """
    Test running a pipeline with a cache of extra small size and spilling true

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap extra small size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_extra_small_size1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_extra_small_size2():
    """
    Test running a pipeline with a cache of extra small size and spilling false (failure)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap extra small size 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=1, spilling=False)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds1])
    assert "Out of memory" in str(e.value)
    logger.info("test_cache_nomap_extra_small_size2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_parallel_pipeline1(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after leaf op

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap parallel pipeline 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard), cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_parallel_pipeline2(shard):
    """
    Test running two parallel pipelines (sharing cache) with cache injected after map op

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap parallel pipeline 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_shards=3, shard_id=int(shard))
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 4
    logger.info("test_cache_nomap_parallel_pipeline2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_parallel_workers():
    """
    Test cache with num_parallel_workers > 1 set for map op and leaf op
    (the cache sits over the map, matching the cache arg on the map call below)

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap parallel workers")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=4)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_parallel_workers Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_server_workers_1():
    """
    Start the cache server with --workers 1 and then test the cache function

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap server workers 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_server_workers_100():
    """
    Start the cache server with --workers 100 and then test the cache function

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap server workers 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_server_workers_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_num_connections_1():
    """
    Test setting num_connections=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap num_connections 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=1)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_num_connections_100():
    """
    Test setting num_connections=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap num_connections 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, num_connections=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_num_connections_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_prefetch_size_1():
    """
    Test setting prefetch_size=1 in DatasetCache

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap prefetch_size 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=1)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_prefetch_size_100():
    """
    Test setting prefetch_size=100 in DatasetCache

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap prefetch_size 100")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True, prefetch_size=100)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator():
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 12
    logger.info("test_cache_nomap_prefetch_size_100 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_to_device():
    """
    Test cache with to_device
    (the cache sits over the map, matching the cache arg on the map call below)

     DeviceQueue
         |
      EpochCtrl
         |
       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap to_device")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)
    ds1 = ds1.to_device()
    ds1.send()
    logger.info("test_cache_nomap_to_device Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_session_destroy():
    """
    Test executing cache_admin -d while the pipeline is running

       Repeat
         |
       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap session destroy")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Unexpected error" in str(e.value)
    logger.info("test_cache_nomap_session_destroy Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_server_stop():
    """
    Test executing cache_admin --stop while the pipeline is running

       Repeat
         |
       Cache
         |
    RandomDataset
    """
    logger.info("Test cache nomap server stop")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # User-created sampler here
    ds1 = ds.RandomDataset(schema=schema, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat()

    with pytest.raises(RuntimeError) as e:
        num_iter = 0
        for _ in ds1.create_dict_iterator():
            num_iter += 1
    assert "Network error. Cache server is unreachable. Make sure the server is running." in str(e.value)
    logger.info("test_cache_nomap_server_stop Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_epoch_ctrl1():
    """
    Test using the two-loop method to run several epochs

     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)

    num_epoch = 5
    iter1 = ds1.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_epoch_ctrl1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_epoch_ctrl2():
    """
    Test using the two-loop method with infinite epochs

       Cache
         |
     Map(decode)
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch

    # manually stop the iterator
    iter1.stop()
    logger.info("test_cache_nomap_epoch_ctrl2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_epoch_ctrl3():
    """
    Test using the two-loop method with infinite epochs over repeat

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFRecord
    """
    logger.info("Test cache nomap epoch ctrl3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    some_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(2)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shut down
    iter1 = ds1.create_dict_iterator()

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            row_count += 1
        logger.info("Number of data in ds1: {} ".format(row_count))
        assert row_count == 6
        epoch_count += 1
    assert epoch_count == num_epoch

    # rely on the garbage collector to destroy iter1
    logger.info("test_cache_nomap_epoch_ctrl3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_multiple_cache1():
    """
    Test multiple caches in the same Python script

        Cache               Cache
          |                   |
     Map(decode)         Map(decode)
          |                   |
    TFRecord(train)     TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache 1")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    # This dataset has 3 records in it only
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_multiple_cache2():
    """
    Test multiple caches in the same Python script

        Cache
          |
     Map(decode)          Cache
          |                 |
    TFRecord(image)   TFRecord(text)
    """
    logger.info("Test cache nomap multiple cache 2")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    text_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    image_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    # This dataset has 3 records in it only
    text_dataset = ds.TFRecordDataset(DATA_DIR2, SCHEMA_DIR2, cache=text_cache)

    num_epoch = 5
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)
    text_iter = text_dataset.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)

    epoch_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _, _ in itertools.zip_longest(image_iter, text_iter):
            row_count += 1
        assert row_count == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_multiple_cache3():
    """
    Test multiple caches in the same Python script

        Cache             Cache
          |                 |
     Map(decode)       Map(decode)
          |                 |
      TFRecord         ImageFolder
    """
    logger.info("Test cache nomap multiple cache 3")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    tf_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)
    image_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 3 records in it only
    tf_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    tf_dataset = tf_dataset.map(input_columns=["image"], operations=decode_op, cache=tf_cache)

    # This DATA_DIR only has 2 images in it
    image_dataset = ds.ImageFolderDataset(dataset_dir=DATA_DIR4)
    image_dataset = image_dataset.map(input_columns=["image"], operations=decode_op, cache=image_cache)

    num_epoch = 5
    tf_iter = tf_dataset.create_dict_iterator(num_epochs=num_epoch)
    image_iter = image_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in tf_iter]) == 3
        assert sum([1 for _ in image_iter]) == 2
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_multiple_cache_train():
    """
    Test multiple caches in different Python scripts. This test case is going to run
    concurrently with test_cache_nomap_multiple_cache_eval.

        Cache
          |
     Map(decode)
          |
    TFRecord(train)
    """
    logger.info("Test cache nomap multiple cache train")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    train_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset has 12 records in it
    train_dataset = ds.TFRecordDataset(DATA_DIR3, SCHEMA_DIR3)
    decode_op = c_vision.Decode()
    train_dataset = train_dataset.map(input_columns=["image"], operations=decode_op, cache=train_cache)

    num_epoch = 5
    train_iter = train_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in train_iter]) == 12
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache_train Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_multiple_cache_eval():
    """
    Test multiple caches in different Python scripts. This test case is going to run
    concurrently with test_cache_nomap_multiple_cache_train.

        Cache
          |
     Map(decode)
          |
    TFRecord(eval)
    """
    logger.info("Test cache nomap multiple cache eval")
    if "SESSION_ID" in os.environ:
        session_id = int(os.environ['SESSION_ID'])
    else:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")

    eval_cache = ds.DatasetCache(session_id=session_id, size=0, spilling=True)

    # This dataset only has 3 records in it
    eval_dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
    decode_op = c_vision.Decode()
    eval_dataset = eval_dataset.map(input_columns=["image"], operations=decode_op, cache=eval_cache)

    num_epoch = 5
    eval_iter = eval_dataset.create_dict_iterator(num_epochs=num_epoch)

    epoch_count = 0
    for _ in range(num_epoch):
        assert sum([1 for _ in eval_iter]) == 3
        epoch_count += 1
    assert epoch_count == num_epoch
    logger.info("test_cache_nomap_multiple_cache_eval Ended.\n")


if __name__ == '__main__':
    test_cache_nomap_basic1()
    test_cache_nomap_basic2()
    test_cache_nomap_basic3()
    test_cache_nomap_basic4()
    test_cache_nomap_basic5()
    test_cache_nomap_basic6()
    test_cache_nomap_basic7()
    test_cache_nomap_allowed_share1()
    test_cache_nomap_allowed_share2()
    test_cache_nomap_allowed_share3()
    test_cache_nomap_allowed_share4()
    test_cache_nomap_disallowed_share1()