You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_reset.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing pipeline Reset
  17. """
  18. import numpy as np
  19. import pytest
  20. import mindspore.dataset as ds
  21. def create_np_dataset(size):
  22. data = ds.NumpySlicesDataset(list(range(1, size + 1)), shuffle=False)
  23. return data
  24. def util(data, num_epochs, failure_point: int, reset_step):
  25. size = data.get_dataset_size()
  26. expected = []
  27. expected_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
  28. for _ in range(num_epochs):
  29. for d in expected_itr:
  30. expected.append(d)
  31. del expected_itr
  32. actual_before_reset = []
  33. itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True)
  34. ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212
  35. cur_step: int = 0
  36. failed = False
  37. for _ in range(num_epochs):
  38. for d in itr:
  39. actual_before_reset.append(d)
  40. if cur_step == failure_point:
  41. ds.engine.datasets._reset_training_dataset(reset_step) # pylint: disable=W0212
  42. failed = True
  43. break
  44. cur_step += 1
  45. if failed:
  46. break
  47. actual_after_reset = []
  48. if failed:
  49. for _ in range(reset_step // size, num_epochs):
  50. for d in itr:
  51. actual_after_reset.append(d)
  52. with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."):
  53. for _ in itr:
  54. pass
  55. for x, y in zip(expected[:failure_point], actual_before_reset):
  56. np.testing.assert_array_equal(x, y)
  57. for x, y in zip(expected[reset_step:], actual_after_reset):
  58. np.testing.assert_array_equal(x, y)
  59. def test_reset():
  60. """
  61. Feature: dataset recovery
  62. Description: Simple test of data pipeline reset feature
  63. Expectation: same datasets after reset
  64. """
  65. dataset_size = 5
  66. num_epochs = 3
  67. data = create_np_dataset(size=dataset_size)
  68. for failure_point in range(dataset_size * num_epochs):
  69. for reset_step in range(dataset_size * num_epochs):
  70. util(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step)
  71. if __name__ == "__main__":
  72. test_reset()