|
|
|
@@ -1,4 +1,5 @@ |
|
|
|
from io import BytesIO |
|
|
|
import copy |
|
|
|
import os |
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
@@ -412,6 +413,46 @@ def test_Mindrecord_Padded(remove_mindrecord_file): |
|
|
|
result_list.append(tem_list) |
|
|
|
assert result_list == verify_list |
|
|
|
|
|
|
|
def test_clue_padded_and_skip_with_0_samples(): |
|
|
|
""" |
|
|
|
Test num_samples param of CLUE dataset |
|
|
|
""" |
|
|
|
TRAIN_FILE = '../data/dataset/testCLUE/afqmc/train.json' |
|
|
|
|
|
|
|
data = ds.CLUEDataset(TRAIN_FILE, task='AFQMC', usage='train') |
|
|
|
count = 0 |
|
|
|
for _ in data.create_dict_iterator(): |
|
|
|
count += 1 |
|
|
|
assert count == 3 |
|
|
|
|
|
|
|
data_copy1 = copy.deepcopy(data) |
|
|
|
|
|
|
|
sample = {"label": np.array(1, np.string_), |
|
|
|
"sentence1": np.array(1, np.string_), |
|
|
|
"sentence2": np.array(1, np.string_)} |
|
|
|
samples = [sample] |
|
|
|
padded_ds = ds.PaddedDataset(samples) |
|
|
|
dataset = data + padded_ds |
|
|
|
testsampler = ds.DistributedSampler(num_shards=2, shard_id=1, shuffle=False, num_samples=None) |
|
|
|
dataset.use_sampler(testsampler) |
|
|
|
assert dataset.get_dataset_size() == 2 |
|
|
|
count = 0 |
|
|
|
for data in dataset.create_dict_iterator(): |
|
|
|
count += 1 |
|
|
|
assert count == 2 |
|
|
|
|
|
|
|
dataset = dataset.skip(count=2) # dataset2 has none samples |
|
|
|
count = 0 |
|
|
|
for data in dataset.create_dict_iterator(): |
|
|
|
count += 1 |
|
|
|
assert count == 0 |
|
|
|
|
|
|
|
with pytest.raises(ValueError, match="There is no samples in the "): |
|
|
|
dataset = dataset.concat(data_copy1) |
|
|
|
count = 0 |
|
|
|
for data in dataset.create_dict_iterator(): |
|
|
|
count += 1 |
|
|
|
assert count == 2 |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_TFRecord_Padded() |
|
|
|
|