|
|
|
@@ -12,6 +12,10 @@ |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
# ============================================================================== |
|
|
|
import os |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
import mindspore.dataset as ds |
|
|
|
from mindspore import log as logger |
|
|
|
|
|
|
|
@@ -26,6 +30,20 @@ DATA_DIR_10 = "../data/dataset/testCifar10Data" |
|
|
|
DATA_DIR_100 = "../data/dataset/testCifar100Data" |
|
|
|
|
|
|
|
|
|
|
|
def load_cifar(path): |
|
|
|
raw = np.empty(0, dtype=np.uint8) |
|
|
|
for file_name in os.listdir(path): |
|
|
|
if file_name.endswith(".bin"): |
|
|
|
with open(os.path.join(path, file_name), mode='rb') as file: |
|
|
|
raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0) |
|
|
|
raw = raw.reshape(-1, 3073) |
|
|
|
labels = raw[:, 0] |
|
|
|
images = raw[:, 1:] |
|
|
|
images = images.reshape(-1, 3, 32, 32) |
|
|
|
images = images.transpose(0, 2, 3, 1) |
|
|
|
return images, labels |
|
|
|
|
|
|
|
|
|
|
|
def test_case_dataset_cifar10(): |
|
|
|
""" |
|
|
|
dataset parameter |
|
|
|
@@ -56,6 +74,18 @@ def test_case_dataset_cifar100(): |
|
|
|
assert num_iter == 100 |
|
|
|
|
|
|
|
|
|
|
|
def test_reading_cifar10(): |
|
|
|
""" |
|
|
|
Validate CIFAR10 image readings |
|
|
|
""" |
|
|
|
data1 = ds.Cifar10Dataset(DATA_DIR_10, 100, shuffle=False) |
|
|
|
images, labels = load_cifar(DATA_DIR_10) |
|
|
|
for i, d in enumerate(data1.create_dict_iterator()): |
|
|
|
np.testing.assert_array_equal(d["image"], images[i]) |
|
|
|
np.testing.assert_array_equal(d["label"], labels[i]) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test_case_dataset_cifar10() |
|
|
|
test_case_dataset_cifar100() |
|
|
|
test_reading_cifar10() |