You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_config.py 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing configuration manager
  17. """
  18. import os
  19. import filecmp
  20. import glob
  21. import numpy as np
  22. import mindspore.dataset as ds
  23. import mindspore.dataset.engine.iterators as it
  24. import mindspore.dataset.transforms.py_transforms
  25. import mindspore.dataset.vision.c_transforms as c_vision
  26. import mindspore.dataset.vision.py_transforms as py_vision
  27. from mindspore import log as logger
  28. from util import dataset_equal
  29. DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
  30. SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  31. def test_basic():
  32. """
  33. Test basic configuration functions
  34. """
  35. # Save original configuration values
  36. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  37. prefetch_size_original = ds.config.get_prefetch_size()
  38. seed_original = ds.config.get_seed()
  39. monitor_sampling_interval_original = ds.config.get_monitor_sampling_interval()
  40. ds.config.load('../data/dataset/declient.cfg')
  41. assert ds.config.get_num_parallel_workers() == 8
  42. # assert ds.config.get_worker_connector_size() == 16
  43. assert ds.config.get_prefetch_size() == 16
  44. assert ds.config.get_seed() == 5489
  45. assert ds.config.get_monitor_sampling_interval() == 15
  46. ds.config.set_num_parallel_workers(2)
  47. # ds.config.set_worker_connector_size(3)
  48. ds.config.set_prefetch_size(4)
  49. ds.config.set_seed(5)
  50. ds.config.set_monitor_sampling_interval(45)
  51. assert ds.config.get_num_parallel_workers() == 2
  52. # assert ds.config.get_worker_connector_size() == 3
  53. assert ds.config.get_prefetch_size() == 4
  54. assert ds.config.get_seed() == 5
  55. assert ds.config.get_monitor_sampling_interval() == 45
  56. # Restore original configuration values
  57. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  58. ds.config.set_prefetch_size(prefetch_size_original)
  59. ds.config.set_seed(seed_original)
  60. ds.config.set_monitor_sampling_interval(monitor_sampling_interval_original)
  61. def test_get_seed():
  62. """
  63. This gets the seed value without explicitly setting a default, expect int.
  64. """
  65. assert isinstance(ds.config.get_seed(), int)
  66. def test_pipeline():
  67. """
  68. Test that our configuration pipeline works when we set parameters at different locations in dataset code
  69. """
  70. # Save original configuration values
  71. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  72. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
  73. data1 = data1.map(operations=[c_vision.Decode(True)], input_columns=["image"])
  74. ds.serialize(data1, "testpipeline.json")
  75. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=num_parallel_workers_original,
  76. shuffle=False)
  77. data2 = data2.map(operations=[c_vision.Decode(True)], input_columns=["image"])
  78. ds.serialize(data2, "testpipeline2.json")
  79. # check that the generated output is different
  80. assert filecmp.cmp('testpipeline.json', 'testpipeline2.json')
  81. # this test passes currently because our num_parallel_workers don't get updated.
  82. # remove generated jason files
  83. file_list = glob.glob('*.json')
  84. for f in file_list:
  85. try:
  86. os.remove(f)
  87. except IOError:
  88. logger.info("Error while deleting: {}".format(f))
  89. # Restore original configuration values
  90. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  91. def test_deterministic_run_fail():
  92. """
  93. Test RandomCrop with seed, expected to fail
  94. """
  95. logger.info("test_deterministic_run_fail")
  96. # Save original configuration values
  97. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  98. seed_original = ds.config.get_seed()
  99. # when we set the seed all operations within our dataset should be deterministic
  100. ds.config.set_seed(0)
  101. ds.config.set_num_parallel_workers(1)
  102. # First dataset
  103. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  104. # Assuming we get the same seed on calling constructor, if this op is re-used then result won't be
  105. # the same in between the two datasets. For example, RandomCrop constructor takes seed (0)
  106. # outputs a deterministic series of numbers, e,g "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
  107. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  108. decode_op = c_vision.Decode()
  109. data1 = data1.map(operations=decode_op, input_columns=["image"])
  110. data1 = data1.map(operations=random_crop_op, input_columns=["image"])
  111. # Second dataset
  112. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  113. data2 = data2.map(operations=decode_op, input_columns=["image"])
  114. # If seed is set up on constructor
  115. data2 = data2.map(operations=random_crop_op, input_columns=["image"])
  116. try:
  117. dataset_equal(data1, data2, 0)
  118. except Exception as e:
  119. # two datasets split the number out of the sequence a
  120. logger.info("Got an exception in DE: {}".format(str(e)))
  121. assert "Array" in str(e)
  122. # Restore original configuration values
  123. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  124. ds.config.set_seed(seed_original)
  125. def test_seed_undeterministic():
  126. """
  127. Test seed with num parallel workers in c, this test is expected to fail some of the time
  128. """
  129. logger.info("test_seed_undeterministic")
  130. # Save original configuration values
  131. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  132. seed_original = ds.config.get_seed()
  133. ds.config.set_seed(0)
  134. ds.config.set_num_parallel_workers(3)
  135. # First dataset
  136. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  137. # We get the seed when constructor is called
  138. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  139. decode_op = c_vision.Decode()
  140. data1 = data1.map(operations=decode_op, input_columns=["image"])
  141. data1 = data1.map(operations=random_crop_op, input_columns=["image"])
  142. # Second dataset
  143. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  144. data2 = data2.map(operations=decode_op, input_columns=["image"])
  145. # Since seed is set up on constructor, so the two ops output deterministic sequence.
  146. # Assume the generated random sequence "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
  147. random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  148. data2 = data2.map(operations=random_crop_op2, input_columns=["image"])
  149. try:
  150. dataset_equal(data1, data2, 0)
  151. except Exception as e:
  152. # two datasets both use numbers from the generated sequence "a"
  153. logger.info("Got an exception in DE: {}".format(str(e)))
  154. assert "Array" in str(e)
  155. # Restore original configuration values
  156. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  157. ds.config.set_seed(seed_original)
  158. def test_seed_deterministic():
  159. """
  160. Test deterministic run with setting the seed, only works with num_parallel worker = 1
  161. """
  162. logger.info("test_seed_deterministic")
  163. # Save original configuration values
  164. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  165. seed_original = ds.config.get_seed()
  166. ds.config.set_seed(0)
  167. ds.config.set_num_parallel_workers(1)
  168. # First dataset
  169. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  170. # seed will be read in during constructor call
  171. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  172. decode_op = c_vision.Decode()
  173. data1 = data1.map(operations=decode_op, input_columns=["image"])
  174. data1 = data1.map(operations=random_crop_op, input_columns=["image"])
  175. # Second dataset
  176. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  177. data2 = data2.map(operations=decode_op, input_columns=["image"])
  178. # If seed is set up on constructor, so the two ops output deterministic sequence
  179. random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  180. data2 = data2.map(operations=random_crop_op2, input_columns=["image"])
  181. dataset_equal(data1, data2, 0)
  182. # Restore original configuration values
  183. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  184. ds.config.set_seed(seed_original)
  185. def test_deterministic_run_distribution():
  186. """
  187. Test deterministic run with with setting the seed being used in a distribution
  188. """
  189. logger.info("test_deterministic_run_distribution")
  190. # Save original configuration values
  191. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  192. seed_original = ds.config.get_seed()
  193. # when we set the seed all operations within our dataset should be deterministic
  194. ds.config.set_seed(0)
  195. ds.config.set_num_parallel_workers(1)
  196. # First dataset
  197. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  198. random_horizontal_flip_op = c_vision.RandomHorizontalFlip(0.1)
  199. decode_op = c_vision.Decode()
  200. data1 = data1.map(operations=decode_op, input_columns=["image"])
  201. data1 = data1.map(operations=random_horizontal_flip_op, input_columns=["image"])
  202. # Second dataset
  203. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  204. data2 = data2.map(operations=decode_op, input_columns=["image"])
  205. # If seed is set up on constructor, so the two ops output deterministic sequence
  206. random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1)
  207. data2 = data2.map(operations=random_horizontal_flip_op2, input_columns=["image"])
  208. dataset_equal(data1, data2, 0)
  209. # Restore original configuration values
  210. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  211. ds.config.set_seed(seed_original)
  212. def test_deterministic_python_seed():
  213. """
  214. Test deterministic execution with seed in python
  215. """
  216. logger.info("test_deterministic_python_seed")
  217. # Save original configuration values
  218. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  219. seed_original = ds.config.get_seed()
  220. ds.config.set_seed(0)
  221. ds.config.set_num_parallel_workers(1)
  222. # First dataset
  223. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  224. transforms = [
  225. py_vision.Decode(),
  226. py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
  227. py_vision.ToTensor(),
  228. ]
  229. transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
  230. data1 = data1.map(operations=transform, input_columns=["image"])
  231. data1_output = []
  232. # config.set_seed() calls random.seed()
  233. for data_one in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  234. data1_output.append(data_one["image"])
  235. # Second dataset
  236. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  237. data2 = data2.map(operations=transform, input_columns=["image"])
  238. # config.set_seed() calls random.seed(), resets seed for next dataset iterator
  239. ds.config.set_seed(0)
  240. data2_output = []
  241. for data_two in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  242. data2_output.append(data_two["image"])
  243. np.testing.assert_equal(data1_output, data2_output)
  244. # Restore original configuration values
  245. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  246. ds.config.set_seed(seed_original)
  247. def test_deterministic_python_seed_multi_thread():
  248. """
  249. Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
  250. """
  251. logger.info("test_deterministic_python_seed_multi_thread")
  252. # Sometimes there are some ITERATORS left in ITERATORS_LIST when run all UTs together,
  253. # and cause core dump and blocking in this UT. Add cleanup() here to fix it.
  254. it._cleanup() # pylint: disable=W0212
  255. # Save original configuration values
  256. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  257. seed_original = ds.config.get_seed()
  258. ds.config.set_num_parallel_workers(3)
  259. ds.config.set_seed(0)
  260. # when we set the seed all operations within our dataset should be deterministic
  261. # First dataset
  262. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  263. transforms = [
  264. py_vision.Decode(),
  265. py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
  266. py_vision.ToTensor(),
  267. ]
  268. transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
  269. data1 = data1.map(operations=transform, input_columns=["image"], python_multiprocessing=True)
  270. data1_output = []
  271. # config.set_seed() calls random.seed()
  272. for data_one in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  273. data1_output.append(data_one["image"])
  274. # Second dataset
  275. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  276. # If seed is set up on constructor
  277. data2 = data2.map(operations=transform, input_columns=["image"], python_multiprocessing=True)
  278. # config.set_seed() calls random.seed()
  279. ds.config.set_seed(0)
  280. data2_output = []
  281. for data_two in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  282. data2_output.append(data_two["image"])
  283. try:
  284. np.testing.assert_equal(data1_output, data2_output)
  285. except Exception as e:
  286. # expect output to not match during multi-threaded execution
  287. logger.info("Got an exception in DE: {}".format(str(e)))
  288. assert "Array" in str(e)
  289. # Restore original configuration values
  290. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  291. ds.config.set_seed(seed_original)
  292. def test_auto_num_workers_error():
  293. """
  294. Test auto_num_workers error
  295. """
  296. err_msg = ""
  297. try:
  298. ds.config.set_auto_num_workers([1, 2])
  299. except TypeError as e:
  300. err_msg = str(e)
  301. assert "must be of type bool" in err_msg
  302. def test_auto_num_workers():
  303. """
  304. Test auto_num_workers can be set.
  305. """
  306. saved_config = ds.config.get_auto_num_workers()
  307. assert isinstance(saved_config, bool)
  308. # change to a different config
  309. flipped_config = not saved_config
  310. ds.config.set_auto_num_workers(flipped_config)
  311. assert flipped_config == ds.config.get_auto_num_workers()
  312. # now flip this back
  313. ds.config.set_auto_num_workers(saved_config)
  314. assert saved_config == ds.config.get_auto_num_workers()
  315. if __name__ == '__main__':
  316. test_basic()
  317. test_get_seed()
  318. test_pipeline()
  319. test_deterministic_run_fail()
  320. test_seed_undeterministic()
  321. test_seed_deterministic()
  322. test_deterministic_run_distribution()
  323. test_deterministic_python_seed()
  324. test_deterministic_python_seed_multi_thread()
  325. test_auto_num_workers_error()
  326. test_auto_num_workers()