You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_config.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore.dataset as ds
  16. def test_basic():
  17. ds.config.load('../data/dataset/declient.cfg')
  18. # assert ds.config.get_rows_per_buffer() == 32
  19. assert ds.config.get_num_parallel_workers() == 4
  20. # assert ds.config.get_worker_connector_size() == 16
  21. assert ds.config.get_prefetch_size() == 16
  22. assert ds.config.get_seed() == 5489
  23. # ds.config.set_rows_per_buffer(1)
  24. ds.config.set_num_parallel_workers(2)
  25. # ds.config.set_worker_connector_size(3)
  26. ds.config.set_prefetch_size(4)
  27. ds.config.set_seed(5)
  28. # assert ds.config.get_rows_per_buffer() == 1
  29. assert ds.config.get_num_parallel_workers() == 2
  30. # assert ds.config.get_worker_connector_size() == 3
  31. assert ds.config.get_prefetch_size() == 4
  32. assert ds.config.get_seed() == 5
  33. if __name__ == '__main__':
  34. test_basic()