浏览代码

!4868 fix: concat with none sample dataset

Merge pull request !4868 from guozhijian/fix_concat_with_zero_dataset
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 年前
父节点
当前提交
bccb92adf7
共有 2 个文件被更改,包括 54 次插入8 次删除
  1. +13
    -8
      mindspore/dataset/engine/datasets.py
  2. +41
    -0
      tests/ut/python/dataset/test_paddeddataset.py

+ 13
- 8
mindspore/dataset/engine/datasets.py 查看文件

@@ -2310,6 +2310,7 @@ class ConcatDataset(DatasetOp):

Raises:
TypeError: If dataset is not an instance of Dataset.
ValueError: If there is no samples in the one of the datasets.
"""

def __init__(self, datasets):
@@ -2324,15 +2325,19 @@ class ConcatDataset(DatasetOp):
data.parent.append(self)

self.children_sizes_ = [c.get_dataset_size() for c in self.children]
"""
_children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes
whether the data set is mappable. The second element of pair is length of the dataset
"""
child_index = 0
for item in self.children_sizes_:
if item == 0:
raise ValueError("There is no samples in the %dth dataset. Please make sure there are "
"valid samples in the dataset" % child_index)
child_index += 1

# _children_flag_and_nums: A list of pair<int ,int>.The first element of pair is flag that characterizes
# whether the data set is mappable. The second element of pair is length of the dataset
self._children_flag_and_nums = []
"""
_children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
the valid position of the dataset corresponding to the subscript when sampling
"""

# _children_start_end_index_: A list of pair<int ,int>.The elements of pair are used to characterize
# the valid position of the dataset corresponding to the subscript when sampling
self._children_start_end_index_ = []
for index, child in enumerate(self.children):
tem_list = [-1, -1]


+ 41
- 0
tests/ut/python/dataset/test_paddeddataset.py 查看文件

@@ -1,4 +1,5 @@
from io import BytesIO
import copy
import os
import numpy as np
import pytest
@@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file):
result_list.append(tem_list)
assert result_list == verify_list

def test_clue_padded_and_skip_with_0_samples():
"""
Test num_samples param of CLUE dataset
"""
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json'

data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train')
count = 0
for _ in data.create_dict_iterator():
count += 1
assert count == 3

data_copy1 = copy.deepcopy(data)

sample = {"label": np.array(1, np.string_),
"sentence1": np.array(1, np.string_),
"sentence2": np.array(1, np.string_)}
samples = [sample]
padded_ds = ds.PaddedDataset(samples)
dataset = data + padded_ds
testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None)
dataset.use_sampler(testsampler)
assert dataset.get_dataset_size() == 2
count = 0
for data in dataset.create_dict_iterator():
count += 1
assert count == 2

dataset = dataset.skip(count=2) # dataset2 has none samples
count = 0
for data in dataset.create_dict_iterator():
count += 1
assert count == 0

with pytest.raises(ValueError, match="There is no samples in the "):
dataset = dataset.concat(data_copy1)
count = 0
for data in dataset.create_dict_iterator():
count += 1
assert count == 2

if __name__ == '__main__':
test_TFRecord_Padded()


正在加载...
取消
保存