# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ``mindspore.dataset.SQuADDataset`` using the SQuAD1 and SQuAD2 test fixtures."""
import pytest

import mindspore.dataset as ds

DATASET_DIR_V1 = '../data/dataset/testSQuAD/SQuAD1'
DATASET_DIR_V2 = '../data/dataset/testSQuAD/SQuAD2'

# Fragment the pipeline error must contain when a user PyFunc inside ``map`` raises.
_PYFUNC_ERROR_MSG = "map operation: [PyFunc] failed. The corresponding data files"


def _decode_scalar(value):
    """Return the single element of ``value`` (UTF-8 bytes) decoded to ``str``."""
    return value.item().decode("utf8")


def _decode_list(values):
    """Return every element of ``values`` (UTF-8 bytes each) decoded to ``str``."""
    return [v.item().decode("utf8") for v in values]


def _to_list(values):
    """Return every element of ``values`` converted to a plain Python scalar via ``.item()``."""
    return [v.item() for v in values]


def _collect(data, column, convert):
    """Iterate ``data`` for one epoch and return ``convert(row[column])`` for each row, in order.

    Comparing the returned list against an expected list checks both the row count and
    every value, and pytest reports the full diff on mismatch.
    """
    return [convert(row[column]) for row in data.create_dict_iterator(num_epochs=1, output_numpy=True)]


def _count_rows(data):
    """Iterate ``data`` for one epoch and return how many rows it produced."""
    return sum(1 for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True))


def test_squad_basic():
    """
    Feature: SQuADDataset.
    Description: test SQuADDataset with repeat, skip and so on.
    Expectation: the data is processed successfully.
    """
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
    data = data.repeat(2)
    data = data.skip(3)
    expected_result = ["Who is \"The Father of Modern Computers\"?",
                       "When was John von Neumann's birth date?",
                       "Where is John von Neumann's birthplace?"]
    assert _collect(data, 'question', _decode_scalar) == expected_result


def test_squad_num_shards():
    """
    Feature: SQuADDataset.
    Description: test num_shards param of SQuAD dataset.
    Expectation: the data is processed successfully.
    """
    # NOTE(review): shuffle is left at its default here, yet a specific row is expected for
    # shard 2 of 3; confirm the shard assignment is deterministic under the default shuffle.
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', num_shards=3, shard_id=2)
    expected_result = ["Where is John von Neumann's birthplace?"]
    assert _collect(data, 'question', _decode_scalar) == expected_result


def test_squad_num_samples():
    """
    Feature: SQuADDataset.
    Description: test num_samples param of SQuAD dataset.
    Expectation: the data is processed successfully.
    """
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', num_samples=2)
    assert _count_rows(data) == 2


def test_squad_dataset_get_datasetsize():
    """
    Feature: SQuADDataset.
    Description: test get_dataset_size of SQuAD dataset.
    Expectation: the data is processed successfully.
    """
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
    size = data.get_dataset_size()
    assert size == 3


def test_squad_version1():
    """
    Feature: SQuADDataset.
    Description: test SQuAD 1.1 for train, dev and all.
    Expectation: the data is processed successfully.
    """
    # train
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
    expected_result = ["Who is \"The Father of Modern Computers\"?",
                       "When was John von Neumann's birth date?",
                       "Where is John von Neumann's birthplace?"]
    assert _collect(data, 'question', _decode_scalar) == expected_result

    # dev
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='dev', shuffle=False)
    expected_result = ["\"The Mathematical Principles of Natural Philosophy\" is a philosophical philosophy " +
                       "of physics created by British Cognitive Isaac Newton. It was first published in 1687.",
                       "\"The Mathematical Principles of Natural Philosophy\" is a philosophical philosophy " +
                       "of physics created by British Cognitive Isaac Newton. It was first published in 1687."]
    assert _collect(data, 'context', _decode_scalar) == expected_result

    # all
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='all', shuffle=False)
    expected_result = [[0], [122, 122, 122], [18], [162, 162, 162], [55]]
    assert _collect(data, 'answer_start', _to_list) == expected_result


def test_squad_version2():
    """
    Feature: SQuADDataset.
    Description: test SQuAD2.0 for train, dev and all.
    Expectation: the data is processed successfully.
    """
    # train
    data = ds.SQuADDataset(DATASET_DIR_V2, usage='train', shuffle=False)
    expected_result = ["Stephen William Hawking, born on January 8, 1942 in Oxford, England, " +
                       "is one of the greatest modern physicists.",
                       "Stephen William Hawking, born on January 8, 1942 in Oxford, England, " +
                       "is one of the greatest modern physicists."]
    assert _collect(data, 'context', _decode_scalar) == expected_result

    # dev
    data = ds.SQuADDataset(DATASET_DIR_V2, usage='dev', shuffle=False)
    expected_result = ["What is the lifestyle of dolphins?", "Who ate the squid?"]
    assert _collect(data, 'question', _decode_scalar) == expected_result

    # all
    data = ds.SQuADDataset(DATASET_DIR_V2, usage='all', shuffle=False)
    expected_result = [["Oxford, England"],
                       ["live in groups", "live in groups", "live in groups", "live in groups"],
                       ["January 8, 1942"],
                       [""]]
    assert _collect(data, 'text', _decode_list) == expected_result


def test_squad_to_device():
    """
    Feature: SQuADDataset.
    Description: test SQuAD with to_device.
    Expectation: the data is processed successfully.
    """
    data = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
    data = data.to_device()
    data.send()


def test_squad_invalid_dir():
    """
    Feature: SQuADDataset.
    Description: test SQuAD with invalid dir.
    Expectation: throw correct error and message.
    """
    invalid_dataset_dir = '../data/dataset/invalid_dir'
    with pytest.raises(ValueError) as info:
        _ = ds.SQuADDataset(invalid_dataset_dir, usage='train', shuffle=False)
    assert "The folder " + invalid_dataset_dir + " does not exist or is not a directory or permission denied!" \
           in str(info.value)
    assert invalid_dataset_dir in str(info.value)


def test_squad_exception():
    """
    Feature: SQuADDataset.
    Description: test file info in err msg when exception occur of SQuAD dataset.
    Expectation: unable to read in data.
    """

    def exception_func(_item):
        # Always fail so the pipeline surfaces the PyFunc error for the mapped column.
        raise Exception("Error occur!")

    for column in ["context", "question", "answer_start", "text"]:
        data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
        data = data.map(operations=exception_func, input_columns=[column], num_parallel_workers=1)
        # Only the iteration is expected to fail; building the pipeline above must succeed.
        with pytest.raises(RuntimeError) as info:
            for _ in data.create_dict_iterator(num_epochs=1):
                pass
        assert _PYFUNC_ERROR_MSG in str(info.value)


if __name__ == "__main__":
    test_squad_basic()
    test_squad_num_shards()
    test_squad_num_samples()
    test_squad_dataset_get_datasetsize()
    test_squad_version1()
    test_squad_version2()
    test_squad_to_device()
    test_squad_invalid_dir()
    test_squad_exception()