You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_two_level_pipeline.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This is the test module for two level pipeline.
  17. """
  18. import os
  19. import pytest
  20. import mindspore.dataset as ds
  21. from util_minddataset import add_and_remove_cv_file # pylint: disable=unused-import
  22. # pylint: disable=redefined-outer-name
  23. def test_minddtaset_generatordataset_01(add_and_remove_cv_file):
  24. """
  25. Feature: Test basic two level pipeline.
  26. Description: MindDataset + GeneratorDataset
  27. Expectation: Data Iteration Successfully.
  28. """
  29. columns_list = ["data", "file_name", "label"]
  30. num_readers = 1
  31. file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
  32. data_set = ds.MindDataset(file_name + "0", columns_list, num_parallel_workers=num_readers, shuffle=None)
  33. dataset_size = data_set.get_dataset_size()
  34. class MyIterable:
  35. """ custom iteration """
  36. def __init__(self, dataset, dataset_size):
  37. self._iter = None
  38. self._index = 0
  39. self._dataset = dataset
  40. self._dataset_size = dataset_size
  41. def __next__(self):
  42. if self._index >= self._dataset_size:
  43. raise StopIteration
  44. if self._iter:
  45. item = next(self._iter)
  46. self._index += 1
  47. return item
  48. self._iter = self._dataset.create_tuple_iterator(num_epochs=1, output_numpy=True)
  49. return next(self)
  50. def __iter__(self):
  51. self._index = 0
  52. self._iter = None
  53. return self
  54. def __len__(self):
  55. return self._dataset_size
  56. dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
  57. column_names=["data", "file_name", "label"], num_parallel_workers=1)
  58. num_epoches = 3
  59. iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
  60. num_iter = 0
  61. for _ in range(num_epoches):
  62. for _ in iter_:
  63. num_iter += 1
  64. assert num_iter == num_epoches * dataset_size
  65. # pylint: disable=redefined-outer-name
  66. def test_minddtaset_generatordataset_exception_01(add_and_remove_cv_file):
  67. """
  68. Feature: Test basic two level pipeline.
  69. Description: invalid column name in MindDataset
  70. Expectation: throw expected exception.
  71. """
  72. err_columns_list = ["data", "filename", "label"]
  73. num_readers = 1
  74. file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
  75. data_set = ds.MindDataset(file_name + "0", err_columns_list, num_parallel_workers=num_readers, shuffle=None)
  76. dataset_size = data_set.get_dataset_size()
  77. class MyIterable:
  78. """ custom iteration """
  79. def __init__(self, dataset, dataset_size):
  80. self._iter = None
  81. self._index = 0
  82. self._dataset = dataset
  83. self._dataset_size = dataset_size
  84. def __next__(self):
  85. if self._index >= self._dataset_size:
  86. raise StopIteration
  87. if self._iter:
  88. item = next(self._iter)
  89. self._index += 1
  90. return item
  91. self._iter = self._dataset.create_tuple_iterator(num_epochs=1, output_numpy=True)
  92. return next(self)
  93. def __iter__(self):
  94. self._index = 0
  95. self._iter = None
  96. return self
  97. def __len__(self):
  98. return self._dataset_size
  99. dataset = ds.GeneratorDataset(source=MyIterable(data_set, dataset_size),
  100. column_names=["data", "file_name", "label"], num_parallel_workers=1)
  101. num_epoches = 3
  102. iter_ = dataset.create_dict_iterator(num_epochs=3, output_numpy=True)
  103. num_iter = 0
  104. with pytest.raises(RuntimeError) as error_info:
  105. for _ in range(num_epoches):
  106. for _ in iter_:
  107. num_iter += 1
  108. assert 'Unexpected error. Invalid data, column name:' in str(error_info.value)