Browse Source

!2363 fix TextFildDataset and CLUEDataset does not support to_device

Merge pull request !2363 from yanghaitao/yht_fix_textfiledataset
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
bbf69912be
3 changed files with 14 additions and 1 deletions
  1. +1
    -1
      mindspore/dataset/engine/datasets.py
  2. +9
    -0
      tests/ut/python/dataset/test_datasets_clue.py
  3. +4
    -0
      tests/ut/python/dataset/test_datasets_textfileop.py

+ 1
- 1
mindspore/dataset/engine/datasets.py View File

@@ -1001,7 +1001,7 @@ class Dataset:
if isinstance(sampler, samplers.DistributedSampler):
dev_id = sampler.shard_id
return "", dev_id
if isinstance(output_dataset, TFRecordDataset):
if isinstance(output_dataset, (TFRecordDataset, TextFileDataset, CLUEDataset)):
if output_dataset.shard_id is not None:
dev_id = output_dataset.shard_id
return "", dev_id


+ 9
- 0
tests/ut/python/dataset/test_datasets_clue.py View File

@@ -344,6 +344,15 @@ def test_clue_wsc():
})
assert len(buffer) == 3

def test_clue_to_device():
"""
Test CLUE with to_device
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train', shuffle=False)
data = data.to_device()
data.send()


if __name__ == "__main__":
test_clue()


+ 4
- 0
tests/ut/python/dataset/test_datasets_textfileop.py View File

@@ -89,6 +89,10 @@ def test_textline_dataset_get_datasetsize():
size = data.get_dataset_size()
assert size == 3

def test_textline_dataset_to_device():
data = ds.TextFileDataset(DATA_FILE, shuffle=False)
data = data.to_device()
data.send()

if __name__ == "__main__":
test_textline_dataset_one_file()


Loading…
Cancel
Save