From 9f336bb73554c9e438016dafe01e38a7d8b271ec Mon Sep 17 00:00:00 2001 From: yanghaitao Date: Fri, 19 Jun 2020 16:17:54 +0800 Subject: [PATCH] fix TextFileDataset and CLUEDataset failed with to_device --- mindspore/dataset/engine/datasets.py | 2 +- tests/ut/python/dataset/test_datasets_clue.py | 9 +++++++++ tests/ut/python/dataset/test_datasets_textfileop.py | 4 ++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 70e9b763f6..d268cf25da 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1002,7 +1002,7 @@ class Dataset: if isinstance(sampler, samplers.DistributedSampler): dev_id = sampler.shard_id return "", dev_id - if isinstance(output_dataset, TFRecordDataset): + if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)): if output_dataset.shard_id is not None: dev_id = output_dataset.shard_id return "", dev_id diff --git a/tests/ut/python/dataset/test_datasets_clue.py b/tests/ut/python/dataset/test_datasets_clue.py index c49db45abe..e1959acb42 100644 --- a/tests/ut/python/dataset/test_datasets_clue.py +++ b/tests/ut/python/dataset/test_datasets_clue.py @@ -344,6 +344,15 @@ def test_clue_wsc(): }) assert len(buffer) == 3 +def test_clue_to_device(): + """ + Test CLUE with to_device + """ + TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' + data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False) + data = data.to_device() + data.send() + if __name__ == "__main__": test_clue() diff --git a/tests/ut/python/dataset/test_datasets_textfileop.py b/tests/ut/python/dataset/test_datasets_textfileop.py index a1d19d88e4..1732c1817d 100644 --- a/tests/ut/python/dataset/test_datasets_textfileop.py +++ b/tests/ut/python/dataset/test_datasets_textfileop.py @@ -89,6 +89,10 @@ def test_textline_dataset_get_datasetsize(): size = data.get_dataset_size() assert size == 3 +def test_textline_dataset_to_device(): + data = ds.TextFileDataset(DATA_FILE, shuffle=False) + data = data.to_device() + data.send() if __name__ == "__main__": test_textline_dataset_one_file()