# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing cache operator with non-mappable datasets
"""
import os
import pytest

import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as c_vision
from mindspore import log as logger

DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"

GENERATE_GOLDEN = False
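
# All tests below require a running dataset cache server; they are skipped
# unless the environment variable RUN_CACHE_TEST is set to 'TRUE'.
# The TF record file in DATA_DIR contains only 3 records.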


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic1():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 1")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # Create a cache. Arbitrary session_id for now.
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # No user-created sampler is given here
    ds1 = ds.RandomDataset(schema=schema, total_rows=10, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
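    # 10 rows in the random dataset, repeated 4 times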
    assert num_iter == 40
    logger.info("test_cache_nomap_basic1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic2():
    """
    A random dataset (a non-mappable dataset) with a cache over it just after the leaf
    """
    logger.info("Test cache nomap basic 2")

    schema = ds.Schema()
    schema.add_column('image', de_type=mstype.uint8,
                      shape=[640, 480, 3])  # 921600 bytes (a bit less than 1 MB per image)
    schema.add_column('label', de_type=mstype.uint8, shape=[1])

    # Create a cache. Arbitrary session_id for now.
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # The sampler arg is not given directly; however, any of these args will auto-generate
    # an appropriate sampler: num_samples, shuffle, num_shards, shard_id.
    # In this case, the presence of num_samples chooses the sampler.
    ds1 = ds.RandomDataset(schema=schema, total_rows=20, num_samples=20, num_parallel_workers=4, cache=some_cache)
    ds1 = ds1.repeat(2)

    num_iter = 0
    for data in ds1.create_dict_iterator(num_epochs=1):
        logger.info("printing the label: {}".format(data["label"]))
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
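    # 20 samples, repeated 2 times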
    assert num_iter == 40
    logger.info("test_cache_nomap_basic2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic3():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 3")

    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
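    # 3 records in the dataset, repeated 4 times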
    assert num_iter == 12

    # Contact the server to get the statistics
    stat = some_cache.GetStat()
    cache_sz = stat.avg_cache_sz
    num_mem_cached = stat.num_mem_cached
    num_disk_cached = stat.num_disk_cached

    logger.info("Number of rows cached in memory: {}".format(num_mem_cached))
    logger.info("Number of rows spilled to disk: {}".format(num_disk_cached))
    logger.info("Average row cache size: {}".format(cache_sz))

    logger.info("test_cache_nomap_basic3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic4():
    """
    A TF reader dataset (a non-mappable dataset) with a map decode and cache after it.
    Since a global shuffle is used for the tf reader, it will inject a shuffle op over the tf.
    But, if there's a cache later, that shuffle becomes invalid and should be removed.

       Repeat
         |
       Cache
         |
     Map(decode)
         |
      TFReader
    """
    logger.info("Test cache nomap basic 4")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # With shuffle not being set, TF defaults to a "global" shuffle when there is no cache
    # in the picture. This causes a shuffle-injection over the TF. For clarity, this test will
    # explicitly give the global option, even though it's the default in python.
    # But, when caching is added in the tree above TF, we do global shuffling
    # through the sampler over the cache, not by the shuffle op. In that case, tree prepare
    # will remove the shuffle op that got injected by the initial tree creation.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL)
    decode_op = c_vision.Decode()

    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
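    # 3 records in the dataset, repeated 4 times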
    assert num_iter == 12
    logger.info("test_cache_nomap_basic4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic5():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf.
    Same as test 3, but this one does not have the shuffle arg, causing tf to default to
    global shuffle, which attempts to inject a shuffle operator. However, since there is a
    cache, we do not need global shuffle, so the shuffle will not be built. It ends up being
    identical to test basic 3; however, we arrive at the same tree via different codepaths
    (if there was no cache, then the shuffle IS built).

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 5")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
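    # 3 records in the dataset, repeated 4 times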
    assert num_iter == 12
    logger.info("test_cache_nomap_basic5 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic6():
    """
    A TF reader dataset (a non-mappable dataset) with a cache over it just after the leaf.
    In this one, the tf dataset is given a sharding configuration; however, since a cache is
    used, the tree prepare should undo the sharding configuration and instead choose a
    distributed sampler with the same shard config.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 6")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    # With only 3 records sharded into 3, we expect only 1 record returned for this shard.
    # However, the sharding will be done by the sampler, not by the tf record leaf node.
    # In this case, it is a row-based sharding, not the file-based sharding that would
    # happen if there was no cache.
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_shards=3, shard_id=1, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
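    # 1 record in this shard (3 records sharded 3 ways), repeated 4 times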
    assert num_iter == 4
    logger.info("test_cache_nomap_basic6 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_basic7():
    """
    A TF reader dataset (a non-mappable dataset) that uses global shuffle and is cached,
    followed by a map.
    In this one, the tf dataset with global shuffle might want to inject a shuffle op over
    top of the tf reader, but since a cache is given, it will choose not to.

       Repeat
         |
     Map(decode)
         |
       Cache
         |
      TFReader
    """
    logger.info("Test cache nomap basic 7")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=ds.Shuffle.GLOBAL, cache=some_cache)
    decode_op = c_vision.Decode()
    ds1 = ds1.map(input_columns=["image"], operations=decode_op)
    ds1 = ds1.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1

    logger.info("Number of data in ds1: {} ".format(num_iter))
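    # 3 records in the dataset, repeated 4 times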
    assert num_iter == 12
    logger.info("test_cache_nomap_basic7 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share1():
    """
    It is allowed to share the cache between the following two trees:

       Repeat     Shuffle
         |           |
       Cache       Cache
         |           |
      TFReader    TFReader
    """
    logger.info("Test cache nomap allowed share 1")

    ds.config.set_seed(1)

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True, prefetch_size=32)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False, cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
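    # 3 records in the dataset, repeated 4 times in ds1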
    assert num_iter == 12
    logger.info("Number of data in ds1: {} ".format(num_iter))

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    # ds2 is not repeated, so it returns the 3 records once
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share1 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share2():
    """
    It is allowed to share the cache between the following two trees (with map decode):

       Repeat      Shuffle
         |            |
       Cache        Cache
         |            |
    Map(decode)   Map(decode)
         |            |
      TFReader     TFReader
    """
    logger.info("Test cache nomap allowed share 2")

    ds.config.set_seed(1)

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache)
    ds2 = ds2.shuffle(buffer_size=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
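    # 3 records in the dataset, repeated 4 times in ds1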
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    # ds2 is not repeated, so it returns the 3 records once
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share2 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share3():
    """
    It is allowed to share the cache between the following two trees (different shard ids):

         Repeat                  Repeat
           |                       |
         Cache                   Cache
           |                       |
    TFReader(shard_id = 0)  TFReader(shard_id = 1)
    """
    logger.info("Test cache nomap allowed share 3")

    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)

    tf_files = ["../data/dataset/tf_file_dataset/test1.data", "../data/dataset/tf_file_dataset/test2.data"]
    ds1 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=0, num_samples=3, shuffle=False, cache=some_cache)
    ds1 = ds1.repeat(4)

    ds2 = ds.TFRecordDataset(tf_files, num_shards=2, shard_id=1, num_samples=3, shuffle=False, cache=some_cache)
    ds2 = ds2.repeat(4)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
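    # 3 samples per shard, repeated 4 times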
    assert num_iter == 12

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12
    logger.info("test_cache_nomap_allowed_share3 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_allowed_share4():
    """
    It is allowed to share the cache between the following two trees:

                  Cache                                 Cache
                    |                                     |
    Map(decode, num_parallel_workers=1)   Map(decode, num_parallel_workers=2)
                    |                                     |
                 TFReader                              TFReader
    """
    logger.info("Test cache nomap allowed share 4")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=2, size=0, spilling=True)
    decode_op = c_vision.Decode()

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache, num_parallel_workers=1)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(input_columns=["image"], operations=decode_op, cache=some_cache, num_parallel_workers=2)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
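    # Both pipelines read the same 3 records; only num_parallel_workers differs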
    assert num_iter == 3

    num_iter = 0
    for _ in ds2.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds2: {} ".format(num_iter))
    assert num_iter == 3
    logger.info("test_cache_nomap_allowed_share4 Ended.\n")


@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Requires a running cache server")
def test_cache_nomap_disallowed_share1():
    """
    It is not allowed to share the cache between the following two trees:

        Cache           Cache
          |               |
    Map(decode)     Map(rescale)
          |               |
      TFReader        TFReader
    """
    logger.info("Test cache nomap disallowed share1")

    # This dataset has 3 records in it only
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=True)
    decode_op = c_vision.Decode()
    rescale_op = c_vision.Rescale(1.0 / 255.0, -1.0)

    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)

    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
    ds2 = ds2.map(input_columns=["image"], operations=rescale_op, cache=some_cache)
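    # The two pipelines apply different map ops, so they would produce different
    # cached rows; the cache server must reject reuse of the same cache here.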
    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    logger.info("Number of data in ds1: {} ".format(num_iter))
    assert num_iter == 3

    # Fail the test if iterating ds2 does NOT raise the expected error
    with pytest.raises(RuntimeError) as e:
        sum([1 for _ in ds2])
    logger.info("Got an exception in DE: {}".format(str(e.value)))
    assert "Attempt to re-use a cache for a different tree!" in str(e.value)

    logger.info("test_cache_nomap_disallowed_share1 Ended.\n")


if __name__ == '__main__':
    test_cache_nomap_basic1()
    test_cache_nomap_basic2()
    test_cache_nomap_basic3()
    test_cache_nomap_basic4()
    test_cache_nomap_basic5()
    test_cache_nomap_basic6()
    test_cache_nomap_basic7()
    test_cache_nomap_allowed_share1()
    test_cache_nomap_allowed_share2()
    test_cache_nomap_allowed_share3()
    test_cache_nomap_allowed_share4()
    test_cache_nomap_disallowed_share1()