# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Testing pipeline Reset """ import numpy as np import pytest import mindspore.dataset as ds def create_np_dataset(size): data = ds.NumpySlicesDataset(list(range(1, size + 1)), shuffle=False) return data def util(data, num_epochs, failure_point: int, reset_step): size = data.get_dataset_size() expected = [] expected_itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True) for _ in range(num_epochs): for d in expected_itr: expected.append(d) del expected_itr actual_before_reset = [] itr = data.create_tuple_iterator(num_epochs=num_epochs, output_numpy=True) ds.engine.datasets._set_training_dataset(itr) # pylint: disable=W0212 cur_step: int = 0 failed = False for _ in range(num_epochs): for d in itr: actual_before_reset.append(d) if cur_step == failure_point: ds.engine.datasets._reset_training_dataset(reset_step) # pylint: disable=W0212 failed = True break cur_step += 1 if failed: break actual_after_reset = [] if failed: for _ in range(reset_step // size, num_epochs): for d in itr: actual_after_reset.append(d) with pytest.raises(RuntimeError, match="User tries to fetch data beyond the specified number of epochs."): for _ in itr: pass for x, y in zip(expected[:failure_point], actual_before_reset): np.testing.assert_array_equal(x, y) for x, y in zip(expected[reset_step:], actual_after_reset): np.testing.assert_array_equal(x, y) def test_reset(): """ Feature: dataset recovery Description: Simple test of data pipeline reset feature Expectation: same datasets after reset """ dataset_size = 5 num_epochs = 3 data = create_np_dataset(size=dataset_size) for failure_point in range(dataset_size * num_epochs): for reset_step in range(dataset_size * num_epochs): util(data, num_epochs=num_epochs, failure_point=failure_point, reset_step=reset_step) if __name__ == "__main__": test_reset()