|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import pytest
-
- import mindspore.dataset as ds
-
# Fixture directories used by the tests below: SQuAD1 holds the SQuAD 1.1
# style files and SQuAD2 the SQuAD 2.0 style files. Both paths are relative,
# so they are resolved against the working directory pytest is launched from.
DATASET_DIR_V1 = '../data/dataset/testSQuAD/SQuAD1'
DATASET_DIR_V2 = '../data/dataset/testSQuAD/SQuAD2'
-
-
def test_squad_basic():
    """
    Feature: SQuADDataset.
    Description: test SQuADDataset combined with the repeat and skip operations.
    Expectation: after repeating twice and skipping three rows, the remaining questions match in order.
    """
    pipeline = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
    # The train split has 3 rows: repeat(2) yields 6, skip(3) leaves the second copy.
    pipeline = pipeline.repeat(2).skip(3)

    questions = [row['question'].item().decode("utf8")
                 for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True)]

    assert questions == ["Who is \"The Father of Modern Computers\"?",
                         "When was John von Neumann's birth date?",
                         "Where is John von Neumann's birthplace?"]
-
-
def test_squad_num_shards():
    """
    Feature: SQuADDataset.
    Description: test the num_shards and shard_id parameters of SQuADDataset.
    Expectation: the last of three shards contains exactly the expected question.
    """
    sharded = ds.SQuADDataset(DATASET_DIR_V1, usage='train', num_shards=3, shard_id=2)

    questions = [row['question'].item().decode("utf8")
                 for row in sharded.create_dict_iterator(num_epochs=1, output_numpy=True)]

    assert questions == ["Where is John von Neumann's birthplace?"]
-
-
def test_squad_num_samples():
    """
    Feature: SQuADDataset.
    Description: test the num_samples parameter of SQuADDataset.
    Expectation: iteration stops after exactly num_samples rows.
    """
    sampled = ds.SQuADDataset(DATASET_DIR_V1, usage='train', num_samples=2)
    rows_seen = sum(1 for _ in sampled.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert rows_seen == 2
-
-
def test_squad_dataset_get_datasetsize():
    """
    Feature: SQuADDataset.
    Description: test get_dataset_size on the SQuAD train split.
    Expectation: the reported size equals the number of train rows (3).
    """
    train_split = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
    assert train_split.get_dataset_size() == 3
-
-
def test_squad_version1():
    """
    Feature: SQuADDataset.
    Description: test SQuAD 1.1 data with the train, dev and all usages.
    Expectation: each usage yields the expected rows in file order.
    """
    def rows_of(usage):
        # shuffle=False keeps file order, so rows can be compared by position.
        squad = ds.SQuADDataset(DATASET_DIR_V1, usage=usage, shuffle=False)
        return squad.create_dict_iterator(num_epochs=1, output_numpy=True)

    # train: compare the question column.
    train_questions = [row['question'].item().decode("utf8") for row in rows_of('train')]
    assert train_questions == ["Who is \"The Father of Modern Computers\"?",
                               "When was John von Neumann's birth date?",
                               "Where is John von Neumann's birthplace?"]

    # dev: both rows carry the same context paragraph.
    newton_context = ("\"The Mathematical Principles of Natural Philosophy\" is a philosophical philosophy " +
                      "of physics created by British Cognitive Isaac Newton. It was first published in 1687.")
    dev_contexts = [row['context'].item().decode("utf8") for row in rows_of('dev')]
    assert dev_contexts == [newton_context, newton_context]

    # all: rows from every split (5 in total); answer_start is iterated element-wise per row.
    answer_starts = [[offset.item() for offset in row['answer_start']] for row in rows_of('all')]
    assert answer_starts == [[0], [122, 122, 122], [18], [162, 162, 162], [55]]
-
-
def test_squad_version2():
    """
    Feature: SQuADDataset.
    Description: test SQuAD 2.0 data with the train, dev and all usages.
    Expectation: each usage yields the expected rows in file order.
    """
    def rows_of(usage):
        # shuffle=False keeps file order, so rows can be compared by position.
        squad = ds.SQuADDataset(DATASET_DIR_V2, usage=usage, shuffle=False)
        return squad.create_dict_iterator(num_epochs=1, output_numpy=True)

    # train: both rows carry the same context paragraph.
    hawking_context = ("Stephen William Hawking, born on January 8, 1942 in Oxford, England, " +
                       "is one of the greatest modern physicists.")
    train_contexts = [row['context'].item().decode("utf8") for row in rows_of('train')]
    assert train_contexts == [hawking_context, hawking_context]

    # dev: compare the question column.
    dev_questions = [row['question'].item().decode("utf8") for row in rows_of('dev')]
    assert dev_questions == ["What is the lifestyle of dolphins?",
                             "Who ate the squid?"]

    # all: 'text' is iterated element-wise per row. The last row's only text is empty
    # (presumably an unanswerable SQuAD 2.0 question -- verify against the fixture).
    answer_texts = [[text.item().decode("utf8") for text in row['text']] for row in rows_of('all')]
    assert answer_texts == [["Oxford, England"],
                            ["live in groups", "live in groups",
                             "live in groups", "live in groups"],
                            ["January 8, 1942"], [""]]
-
-
def test_squad_to_device():
    """
    Feature: SQuADDataset.
    Description: test transferring a SQuAD pipeline with to_device followed by send.
    Expectation: to_device and send complete without raising.
    """
    squad = ds.SQuADDataset(DATASET_DIR_V1, usage='train', shuffle=False)
    transfer = squad.to_device()
    transfer.send()
-
-
def test_squad_invalid_dir():
    """
    Feature: SQuADDataset.
    Description: test SQuADDataset with a dataset directory that does not exist.
    Expectation: a ValueError is raised whose message names the missing folder.
    """
    missing_dir = '../data/dataset/invalid_dir'
    with pytest.raises(ValueError) as err_info:
        _ = ds.SQuADDataset(missing_dir, usage='train', shuffle=False)

    message = str(err_info.value)
    expected_fragment = f"The folder {missing_dir} does not exist or is not a directory or permission denied!"
    assert expected_fragment in message
    assert missing_dir in message
-
-
def test_squad_exception():
    """
    Feature: SQuADDataset.
    Description: test that the error message includes file info when a map operation fails on any SQuAD column.
    Expectation: iterating raises a RuntimeError that mentions the corresponding data files, for every column.
    """

    def exception_func(_item):
        # Always fail so the pipeline surfaces the PyFunc error.
        raise Exception("Error occur!")

    for column in ["context", "question", "answer_start", "text"]:
        data = ds.SQuADDataset(DATASET_DIR_V1, usage='train')
        data = data.map(operations=exception_func, input_columns=[column],
                        num_parallel_workers=1)
        # pytest.raises fails the test when nothing is raised. The previous
        # `assert False` sentinel was stripped under `python -O`, and its broad
        # `try` also swallowed errors from dataset construction and map.
        with pytest.raises(RuntimeError) as info:
            for _ in data.create_dict_iterator():
                pass
        assert "map operation: [PyFunc] failed. The corresponding data files" \
            in str(info.value)
-
-
if __name__ == "__main__":
    # Run every test in definition order when executed as a script.
    for test_case in (test_squad_basic,
                      test_squad_num_shards,
                      test_squad_num_samples,
                      test_squad_dataset_get_datasetsize,
                      test_squad_version1,
                      test_squad_version2,
                      test_squad_to_device,
                      test_squad_invalid_dir,
                      test_squad_exception):
        test_case()
|