|
|
|
@@ -1,8 +1,12 @@ |
|
|
|
from io import BytesIO |
|
|
|
import os |
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
import mindspore.dataset as ds |
|
|
|
from mindspore.mindrecord import FileWriter |
|
|
|
import mindspore.dataset.transforms.vision.c_transforms as V_C |
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
FILES_NUM = 4 |
|
|
|
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord" |
|
|
|
CV_DIR_NAME = "../data/mindrecord/testImageNetData" |
|
|
|
@@ -197,7 +201,7 @@ def test_raise_error(): |
|
|
|
ds3.use_sampler(testsampler) |
|
|
|
assert excinfo.type == 'ValueError' |
|
|
|
|
|
|
|
def test_imagefolden_padded(): |
|
|
|
def test_imagefolder_padded(): |
|
|
|
DATA_DIR = "../data/dataset/testPK/data" |
|
|
|
data = ds.ImageFolderDatasetV2(DATA_DIR) |
|
|
|
|
|
|
|
@@ -220,6 +224,32 @@ def test_imagefolden_padded(): |
|
|
|
assert verify_list[8] == 1 |
|
|
|
assert verify_list[9] == 6 |
|
|
|
|
|
|
|
def test_imagefolder_padded_with_decode(): |
|
|
|
DATA_DIR = "../data/dataset/testPK/data" |
|
|
|
data = ds.ImageFolderDatasetV2(DATA_DIR) |
|
|
|
|
|
|
|
white_io = BytesIO() |
|
|
|
Image.new('RGB', (224, 224), (255, 255, 255)).save(white_io, 'JPEG') |
|
|
|
padded_sample = {} |
|
|
|
padded_sample['image'] = np.array(bytearray(white_io), dtype='uint8') |
|
|
|
padded_sample['label'] = np.array(-1, np.int32) |
|
|
|
|
|
|
|
white_samples = [padded_sample, padded_sample, padded_sample, padded_sample] |
|
|
|
data2 = ds.PaddedDataset(white_samples) |
|
|
|
data3 = data + data2 |
|
|
|
|
|
|
|
num_shards = 5 |
|
|
|
count = 0 |
|
|
|
for shard_id in range(num_shards): |
|
|
|
testsampler = ds.DistributedSampler(num_shards=num_shards, shard_id=shard_id, shuffle=False, num_samples=None) |
|
|
|
data3.use_sampler(testsampler) |
|
|
|
data3.map(input_columns="image", operations=V_C.Decode()) |
|
|
|
for ele in data3.create_dict_iterator(): |
|
|
|
print("label: {}".format(ele['label'])) |
|
|
|
count += 1 |
|
|
|
assert count == 48 |
|
|
|
|
|
|
|
|
|
|
|
def test_more_shard_padded(): |
|
|
|
result_list = [] |
|
|
|
for i in range(8): |
|
|
|
|