You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_config.py 19 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing configuration manager
  17. """
  18. import os
  19. import filecmp
  20. import glob
  21. import numpy as np
  22. import mindspore.dataset as ds
  23. import mindspore.dataset.engine.iterators as it
  24. import mindspore.dataset.transforms.py_transforms
  25. import mindspore.dataset.vision.c_transforms as c_vision
  26. import mindspore.dataset.vision.py_transforms as py_vision
  27. from mindspore import log as logger
  28. from util import dataset_equal
# TFRecord data file and matching JSON schema shared by the pipeline tests below.
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
  31. def config_error_func(config_interface, input_args, err_type, except_err_msg):
  32. err_msg = ""
  33. try:
  34. config_interface(input_args)
  35. except err_type as e:
  36. err_msg = str(e)
  37. assert except_err_msg in err_msg
  38. def test_basic():
  39. """
  40. Test basic configuration functions
  41. """
  42. # Save original configuration values
  43. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  44. prefetch_size_original = ds.config.get_prefetch_size()
  45. seed_original = ds.config.get_seed()
  46. monitor_sampling_interval_original = ds.config.get_monitor_sampling_interval()
  47. ds.config.load('../data/dataset/declient.cfg')
  48. assert ds.config.get_num_parallel_workers() == 8
  49. # assert ds.config.get_worker_connector_size() == 16
  50. assert ds.config.get_prefetch_size() == 16
  51. assert ds.config.get_seed() == 5489
  52. assert ds.config.get_monitor_sampling_interval() == 15
  53. ds.config.set_num_parallel_workers(2)
  54. # ds.config.set_worker_connector_size(3)
  55. ds.config.set_prefetch_size(4)
  56. ds.config.set_seed(5)
  57. ds.config.set_monitor_sampling_interval(45)
  58. assert ds.config.get_num_parallel_workers() == 2
  59. # assert ds.config.get_worker_connector_size() == 3
  60. assert ds.config.get_prefetch_size() == 4
  61. assert ds.config.get_seed() == 5
  62. assert ds.config.get_monitor_sampling_interval() == 45
  63. # Restore original configuration values
  64. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  65. ds.config.set_prefetch_size(prefetch_size_original)
  66. ds.config.set_seed(seed_original)
  67. ds.config.set_monitor_sampling_interval(monitor_sampling_interval_original)
  68. def test_get_seed():
  69. """
  70. This gets the seed value without explicitly setting a default, expect int.
  71. """
  72. assert isinstance(ds.config.get_seed(), int)
def test_pipeline():
    """
    Test that our configuration pipeline works when we set parameters at different locations in dataset code
    """
    # Save original configuration values
    num_parallel_workers_original = ds.config.get_num_parallel_workers()
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
    data1 = data1.map(operations=[c_vision.Decode(True)], input_columns=["image"])
    ds.serialize(data1, "testpipeline.json")
    # data2 passes num_parallel_workers explicitly, while data1 relies on the global default.
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, num_parallel_workers=num_parallel_workers_original,
                               shuffle=False)
    data2 = data2.map(operations=[c_vision.Decode(True)], input_columns=["image"])
    ds.serialize(data2, "testpipeline2.json")
    # NOTE(review): filecmp.cmp returns True when the files match, so this
    # asserts the two serialized pipelines are IDENTICAL (despite the explicit
    # num_parallel_workers on data2).
    assert filecmp.cmp('testpipeline.json', 'testpipeline2.json')
    # this test passes currently because our num_parallel_workers don't get updated.
    # remove generated json files
    file_list = glob.glob('*.json')
    for f in file_list:
        try:
            os.remove(f)
        except IOError:
            # Best-effort cleanup: log the failure and keep going.
            logger.info("Error while deleting: {}".format(f))
    # Restore original configuration values
    ds.config.set_num_parallel_workers(num_parallel_workers_original)
  98. def test_deterministic_run_fail():
  99. """
  100. Test RandomCrop with seed, expected to fail
  101. """
  102. logger.info("test_deterministic_run_fail")
  103. # Save original configuration values
  104. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  105. seed_original = ds.config.get_seed()
  106. # when we set the seed all operations within our dataset should be deterministic
  107. ds.config.set_seed(0)
  108. ds.config.set_num_parallel_workers(1)
  109. # First dataset
  110. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  111. # Assuming we get the same seed on calling constructor, if this op is re-used then result won't be
  112. # the same in between the two datasets. For example, RandomCrop constructor takes seed (0)
  113. # outputs a deterministic series of numbers, e,g "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
  114. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  115. decode_op = c_vision.Decode()
  116. data1 = data1.map(operations=decode_op, input_columns=["image"])
  117. data1 = data1.map(operations=random_crop_op, input_columns=["image"])
  118. # Second dataset
  119. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  120. data2 = data2.map(operations=decode_op, input_columns=["image"])
  121. # If seed is set up on constructor
  122. data2 = data2.map(operations=random_crop_op, input_columns=["image"])
  123. try:
  124. dataset_equal(data1, data2, 0)
  125. except Exception as e:
  126. # two datasets split the number out of the sequence a
  127. logger.info("Got an exception in DE: {}".format(str(e)))
  128. assert "Array" in str(e)
  129. # Restore original configuration values
  130. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  131. ds.config.set_seed(seed_original)
  132. def test_seed_undeterministic():
  133. """
  134. Test seed with num parallel workers in c, this test is expected to fail some of the time
  135. """
  136. logger.info("test_seed_undeterministic")
  137. # Save original configuration values
  138. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  139. seed_original = ds.config.get_seed()
  140. ds.config.set_seed(0)
  141. ds.config.set_num_parallel_workers(3)
  142. # First dataset
  143. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  144. # We get the seed when constructor is called
  145. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  146. decode_op = c_vision.Decode()
  147. data1 = data1.map(operations=decode_op, input_columns=["image"])
  148. data1 = data1.map(operations=random_crop_op, input_columns=["image"])
  149. # Second dataset
  150. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  151. data2 = data2.map(operations=decode_op, input_columns=["image"])
  152. # Since seed is set up on constructor, so the two ops output deterministic sequence.
  153. # Assume the generated random sequence "a" = [1, 2, 3, 4, 5, 6] <- pretend these are random
  154. random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  155. data2 = data2.map(operations=random_crop_op2, input_columns=["image"])
  156. try:
  157. dataset_equal(data1, data2, 0)
  158. except Exception as e:
  159. # two datasets both use numbers from the generated sequence "a"
  160. logger.info("Got an exception in DE: {}".format(str(e)))
  161. assert "Array" in str(e)
  162. # Restore original configuration values
  163. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  164. ds.config.set_seed(seed_original)
  165. def test_seed_deterministic():
  166. """
  167. Test deterministic run with setting the seed, only works with num_parallel worker = 1
  168. """
  169. logger.info("test_seed_deterministic")
  170. # Save original configuration values
  171. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  172. seed_original = ds.config.get_seed()
  173. ds.config.set_seed(0)
  174. ds.config.set_num_parallel_workers(1)
  175. # First dataset
  176. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  177. # seed will be read in during constructor call
  178. random_crop_op = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  179. decode_op = c_vision.Decode()
  180. data1 = data1.map(operations=decode_op, input_columns=["image"])
  181. data1 = data1.map(operations=random_crop_op, input_columns=["image"])
  182. # Second dataset
  183. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  184. data2 = data2.map(operations=decode_op, input_columns=["image"])
  185. # If seed is set up on constructor, so the two ops output deterministic sequence
  186. random_crop_op2 = c_vision.RandomCrop([512, 512], [200, 200, 200, 200])
  187. data2 = data2.map(operations=random_crop_op2, input_columns=["image"])
  188. dataset_equal(data1, data2, 0)
  189. # Restore original configuration values
  190. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  191. ds.config.set_seed(seed_original)
  192. def test_deterministic_run_distribution():
  193. """
  194. Test deterministic run with with setting the seed being used in a distribution
  195. """
  196. logger.info("test_deterministic_run_distribution")
  197. # Save original configuration values
  198. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  199. seed_original = ds.config.get_seed()
  200. # when we set the seed all operations within our dataset should be deterministic
  201. ds.config.set_seed(0)
  202. ds.config.set_num_parallel_workers(1)
  203. # First dataset
  204. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  205. random_horizontal_flip_op = c_vision.RandomHorizontalFlip(0.1)
  206. decode_op = c_vision.Decode()
  207. data1 = data1.map(operations=decode_op, input_columns=["image"])
  208. data1 = data1.map(operations=random_horizontal_flip_op, input_columns=["image"])
  209. # Second dataset
  210. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  211. data2 = data2.map(operations=decode_op, input_columns=["image"])
  212. # If seed is set up on constructor, so the two ops output deterministic sequence
  213. random_horizontal_flip_op2 = c_vision.RandomHorizontalFlip(0.1)
  214. data2 = data2.map(operations=random_horizontal_flip_op2, input_columns=["image"])
  215. dataset_equal(data1, data2, 0)
  216. # Restore original configuration values
  217. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  218. ds.config.set_seed(seed_original)
  219. def test_deterministic_python_seed():
  220. """
  221. Test deterministic execution with seed in python
  222. """
  223. logger.info("test_deterministic_python_seed")
  224. # Save original configuration values
  225. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  226. seed_original = ds.config.get_seed()
  227. ds.config.set_seed(0)
  228. ds.config.set_num_parallel_workers(1)
  229. # First dataset
  230. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  231. transforms = [
  232. py_vision.Decode(),
  233. py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
  234. py_vision.ToTensor(),
  235. ]
  236. transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
  237. data1 = data1.map(operations=transform, input_columns=["image"])
  238. data1_output = []
  239. # config.set_seed() calls random.seed()
  240. for data_one in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  241. data1_output.append(data_one["image"])
  242. # Second dataset
  243. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  244. data2 = data2.map(operations=transform, input_columns=["image"])
  245. # config.set_seed() calls random.seed(), resets seed for next dataset iterator
  246. ds.config.set_seed(0)
  247. data2_output = []
  248. for data_two in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  249. data2_output.append(data_two["image"])
  250. np.testing.assert_equal(data1_output, data2_output)
  251. # Restore original configuration values
  252. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  253. ds.config.set_seed(seed_original)
  254. def test_deterministic_python_seed_multi_thread():
  255. """
  256. Test deterministic execution with seed in python, this fails with multi-thread pyfunc run
  257. """
  258. logger.info("test_deterministic_python_seed_multi_thread")
  259. # Sometimes there are some ITERATORS left in ITERATORS_LIST when run all UTs together,
  260. # and cause core dump and blocking in this UT. Add cleanup() here to fix it.
  261. it._cleanup() # pylint: disable=W0212
  262. # Save original configuration values
  263. num_parallel_workers_original = ds.config.get_num_parallel_workers()
  264. seed_original = ds.config.get_seed()
  265. mem_original = ds.config.get_enable_shared_mem()
  266. ds.config.set_num_parallel_workers(3)
  267. ds.config.set_seed(0)
  268. # Disable shared memory to save shm in CI
  269. ds.config.set_enable_shared_mem(False)
  270. # when we set the seed all operations within our dataset should be deterministic
  271. # First dataset
  272. data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  273. transforms = [
  274. py_vision.Decode(),
  275. py_vision.RandomCrop([512, 512], [200, 200, 200, 200]),
  276. py_vision.ToTensor(),
  277. ]
  278. transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
  279. data1 = data1.map(operations=transform, input_columns=["image"], python_multiprocessing=True)
  280. data1_output = []
  281. # config.set_seed() calls random.seed()
  282. for data_one in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  283. data1_output.append(data_one["image"])
  284. # Second dataset
  285. data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
  286. # If seed is set up on constructor
  287. data2 = data2.map(operations=transform, input_columns=["image"], python_multiprocessing=True)
  288. # config.set_seed() calls random.seed()
  289. ds.config.set_seed(0)
  290. data2_output = []
  291. for data_two in data2.create_dict_iterator(num_epochs=1, output_numpy=True):
  292. data2_output.append(data_two["image"])
  293. try:
  294. np.testing.assert_equal(data1_output, data2_output)
  295. except Exception as e:
  296. # expect output to not match during multi-threaded execution
  297. logger.info("Got an exception in DE: {}".format(str(e)))
  298. assert "Array" in str(e)
  299. # Restore original configuration values
  300. ds.config.set_num_parallel_workers(num_parallel_workers_original)
  301. ds.config.set_seed(seed_original)
  302. ds.config.set_enable_shared_mem(mem_original)
  303. def test_auto_num_workers_error():
  304. """
  305. Test auto_num_workers error
  306. """
  307. err_msg = ""
  308. try:
  309. ds.config.set_auto_num_workers([1, 2])
  310. except TypeError as e:
  311. err_msg = str(e)
  312. assert "must be of type bool" in err_msg
  313. def test_auto_num_workers():
  314. """
  315. Test auto_num_workers can be set.
  316. """
  317. saved_config = ds.config.get_auto_num_workers()
  318. assert isinstance(saved_config, bool)
  319. # change to a different config
  320. flipped_config = not saved_config
  321. ds.config.set_auto_num_workers(flipped_config)
  322. assert flipped_config == ds.config.get_auto_num_workers()
  323. # now flip this back
  324. ds.config.set_auto_num_workers(saved_config)
  325. assert saved_config == ds.config.get_auto_num_workers()
  326. def test_enable_watchdog():
  327. """
  328. Feature: Test the function of get_enable_watchdog and set_enable_watchdog.
  329. Description: We add this new interface so we can close the watchdog thread
  330. Expectation: The default state is True, when execute set_enable_watchdog, the state will update.
  331. """
  332. saved_config = ds.config.get_enable_watchdog()
  333. assert isinstance(saved_config, bool)
  334. assert saved_config is True
  335. # change to a different config
  336. flipped_config = not saved_config
  337. ds.config.set_enable_watchdog(flipped_config)
  338. assert flipped_config == ds.config.get_enable_watchdog()
  339. # now flip this back
  340. ds.config.set_enable_watchdog(saved_config)
  341. assert saved_config == ds.config.get_enable_watchdog()
  342. def test_multiprocessing_timeout_interval():
  343. """
  344. Feature: Test the function of get_multiprocessing_timeout_interval and set_multiprocessing_timeout_interval.
  345. Description: We add this new interface so we can adjust the timeout of multiprocessing get function.
  346. Expectation: The default state is 300s, when execute set_multiprocessing_timeout_interval, the state will update.
  347. """
  348. saved_config = ds.config.get_multiprocessing_timeout_interval()
  349. assert saved_config == 300
  350. # change to a different config
  351. flipped_config = 1000
  352. ds.config.set_multiprocessing_timeout_interval(flipped_config)
  353. assert flipped_config == ds.config.get_multiprocessing_timeout_interval()
  354. # now flip this back
  355. ds.config.set_multiprocessing_timeout_interval(saved_config)
  356. assert saved_config == ds.config.get_multiprocessing_timeout_interval()
  357. def test_config_bool_type_error():
  358. """
  359. Feature: Now many interfaces of config support bool input even its valid input is int.
  360. Description: We will raise a type error when input is a bool when it should be int.
  361. Expectation: TypeError will be raised when input is a bool.
  362. """
  363. # set_seed will raise TypeError if input is a boolean
  364. config_error_func(ds.config.set_seed, True, TypeError, "seed isn't of type int")
  365. # set_prefetch_size will raise TypeError if input is a boolean
  366. config_error_func(ds.config.set_prefetch_size, True, TypeError, "size isn't of type int")
  367. # set_num_parallel_workers will raise TypeError if input is a boolean
  368. config_error_func(ds.config.set_num_parallel_workers, True, TypeError, "num isn't of type int")
  369. # set_monitor_sampling_interval will raise TypeError if input is a boolean
  370. config_error_func(ds.config.set_monitor_sampling_interval, True, TypeError, "interval isn't of type int")
  371. # set_callback_timeout will raise TypeError if input is a boolean
  372. config_error_func(ds.config.set_callback_timeout, True, TypeError, "timeout isn't of type int")
  373. # set_autotune_interval will raise TypeError if input is a boolean
  374. config_error_func(ds.config.set_autotune_interval, True, TypeError, "interval must be of type int")
  375. # set_sending_batches will raise TypeError if input is a boolean
  376. config_error_func(ds.config.set_sending_batches, True, TypeError, "batch_num must be an int dtype")
  377. # set_multiprocessing_timeout_interval will raise TypeError if input is a boolean
  378. config_error_func(ds.config.set_multiprocessing_timeout_interval, True, TypeError, "interval isn't of type int")
  379. if __name__ == '__main__':
  380. test_basic()
  381. test_get_seed()
  382. test_pipeline()
  383. test_deterministic_run_fail()
  384. test_seed_undeterministic()
  385. test_seed_deterministic()
  386. test_deterministic_run_distribution()
  387. test_deterministic_python_seed()
  388. test_deterministic_python_seed_multi_thread()
  389. test_auto_num_workers_error()
  390. test_auto_num_workers()
  391. test_enable_watchdog()
  392. test_multiprocessing_timeout_interval()
  393. test_config_bool_type_error()