|
|
|
@@ -17,6 +17,7 @@ This is the test module for mindrecord |
|
|
|
""" |
|
|
|
import collections |
|
|
|
import json |
|
|
|
import math |
|
|
|
import os |
|
|
|
import re |
|
|
|
import string |
|
|
|
@@ -1605,3 +1606,149 @@ def test_write_with_multi_array_and_MindDataset(): |
|
|
|
|
|
|
|
os.remove("{}".format(mindrecord_file_name)) |
|
|
|
os.remove("{}.db".format(mindrecord_file_name)) |
|
|
|
|
|
|
|
def test_write_with_float32_float64_float32_array_float64_array_and_MindDataset(): |
|
|
|
mindrecord_file_name = "test.mindrecord" |
|
|
|
data = [{"float32_array": np.array([1.2, 2.78, 3.1234, 4.9871, 5.12341], dtype=np.float32), |
|
|
|
"float64_array": np.array([48.1234556789, 49.3251241431, 50.13514312414, 51.8971298471, |
|
|
|
123414314.2141243, 87.1212122], dtype=np.float64), |
|
|
|
"float32": 3456.12345, |
|
|
|
"float64": 1987654321.123456785, |
|
|
|
"int32_array": np.array([1, 2, 3, 4, 5], dtype=np.int32), |
|
|
|
"int64_array": np.array([48, 49, 50, 51, 123414314, 87], dtype=np.int64), |
|
|
|
"int32": 3456, |
|
|
|
"int64": 947654321123}, |
|
|
|
{"float32_array": np.array([1.2, 2.78, 4.1234, 4.9871, 5.12341], dtype=np.float32), |
|
|
|
"float64_array": np.array([48.1234556789, 49.3251241431, 60.13514312414, 51.8971298471, |
|
|
|
123414314.2141243, 87.1212122], dtype=np.float64), |
|
|
|
"float32": 3456.12445, |
|
|
|
"float64": 1987654321.123456786, |
|
|
|
"int32_array": np.array([11, 21, 31, 41, 51], dtype=np.int32), |
|
|
|
"int64_array": np.array([481, 491, 501, 511, 1234143141, 871], dtype=np.int64), |
|
|
|
"int32": 3466, |
|
|
|
"int64": 957654321123}, |
|
|
|
{"float32_array": np.array([1.2, 2.78, 5.1234, 4.9871, 5.12341], dtype=np.float32), |
|
|
|
"float64_array": np.array([48.1234556789, 49.3251241431, 70.13514312414, 51.8971298471, |
|
|
|
123414314.2141243, 87.1212122], dtype=np.float64), |
|
|
|
"float32": 3456.12545, |
|
|
|
"float64": 1987654321.123456787, |
|
|
|
"int32_array": np.array([12, 22, 32, 42, 52], dtype=np.int32), |
|
|
|
"int64_array": np.array([482, 492, 502, 512, 1234143142, 872], dtype=np.int64), |
|
|
|
"int32": 3476, |
|
|
|
"int64": 967654321123}, |
|
|
|
{"float32_array": np.array([1.2, 2.78, 6.1234, 4.9871, 5.12341], dtype=np.float32), |
|
|
|
"float64_array": np.array([48.1234556789, 49.3251241431, 80.13514312414, 51.8971298471, |
|
|
|
123414314.2141243, 87.1212122], dtype=np.float64), |
|
|
|
"float32": 3456.12645, |
|
|
|
"float64": 1987654321.123456788, |
|
|
|
"int32_array": np.array([13, 23, 33, 43, 53], dtype=np.int32), |
|
|
|
"int64_array": np.array([483, 493, 503, 513, 1234143143, 873], dtype=np.int64), |
|
|
|
"int32": 3486, |
|
|
|
"int64": 977654321123}, |
|
|
|
{"float32_array": np.array([1.2, 2.78, 7.1234, 4.9871, 5.12341], dtype=np.float32), |
|
|
|
"float64_array": np.array([48.1234556789, 49.3251241431, 90.13514312414, 51.8971298471, |
|
|
|
123414314.2141243, 87.1212122], dtype=np.float64), |
|
|
|
"float32": 3456.12745, |
|
|
|
"float64": 1987654321.123456789, |
|
|
|
"int32_array": np.array([14, 24, 34, 44, 54], dtype=np.int32), |
|
|
|
"int64_array": np.array([484, 494, 504, 514, 1234143144, 874], dtype=np.int64), |
|
|
|
"int32": 3496, |
|
|
|
"int64": 987654321123}, |
|
|
|
] |
|
|
|
writer = FileWriter(mindrecord_file_name) |
|
|
|
schema = {"float32_array": {"type": "float32", "shape": [-1]}, |
|
|
|
"float64_array": {"type": "float64", "shape": [-1]}, |
|
|
|
"float32": {"type": "float32"}, |
|
|
|
"float64": {"type": "float64"}, |
|
|
|
"int32_array": {"type": "int32", "shape": [-1]}, |
|
|
|
"int64_array": {"type": "int64", "shape": [-1]}, |
|
|
|
"int32": {"type": "int32"}, |
|
|
|
"int64": {"type": "int64"}} |
|
|
|
writer.add_schema(schema, "data is so cool") |
|
|
|
writer.write_raw_data(data) |
|
|
|
writer.commit() |
|
|
|
|
|
|
|
# change data value to list - do none |
|
|
|
data_value_to_list = [] |
|
|
|
for item in data: |
|
|
|
new_data = {} |
|
|
|
new_data['float32_array'] = item["float32_array"] |
|
|
|
new_data['float64_array'] = item["float64_array"] |
|
|
|
new_data['float32'] = item["float32"] |
|
|
|
new_data['float64'] = item["float64"] |
|
|
|
new_data['int32_array'] = item["int32_array"] |
|
|
|
new_data['int64_array'] = item["int64_array"] |
|
|
|
new_data['int32'] = item["int32"] |
|
|
|
new_data['int64'] = item["int64"] |
|
|
|
data_value_to_list.append(new_data) |
|
|
|
|
|
|
|
num_readers = 2 |
|
|
|
data_set = ds.MindDataset(dataset_file=mindrecord_file_name, |
|
|
|
num_parallel_workers=num_readers, |
|
|
|
shuffle=False) |
|
|
|
assert data_set.get_dataset_size() == 5 |
|
|
|
num_iter = 0 |
|
|
|
for item in data_set.create_dict_iterator(): |
|
|
|
assert len(item) == 8 |
|
|
|
for field in item: |
|
|
|
if isinstance(item[field], np.ndarray): |
|
|
|
if item[field].dtype == np.float32: |
|
|
|
assert (item[field] == |
|
|
|
np.array(data_value_to_list[num_iter][field], np.float32)).all() |
|
|
|
else: |
|
|
|
assert (item[field] == |
|
|
|
data_value_to_list[num_iter][field]).all() |
|
|
|
else: |
|
|
|
assert item[field] == data_value_to_list[num_iter][field] |
|
|
|
num_iter += 1 |
|
|
|
assert num_iter == 5 |
|
|
|
|
|
|
|
num_readers = 2 |
|
|
|
data_set = ds.MindDataset(dataset_file=mindrecord_file_name, |
|
|
|
columns_list=["float32", "int32"], |
|
|
|
num_parallel_workers=num_readers, |
|
|
|
shuffle=False) |
|
|
|
assert data_set.get_dataset_size() == 5 |
|
|
|
num_iter = 0 |
|
|
|
for item in data_set.create_dict_iterator(): |
|
|
|
assert len(item) == 2 |
|
|
|
for field in item: |
|
|
|
if isinstance(item[field], np.ndarray): |
|
|
|
if item[field].dtype == np.float32: |
|
|
|
assert (item[field] == |
|
|
|
np.array(data_value_to_list[num_iter][field], np.float32)).all() |
|
|
|
else: |
|
|
|
assert (item[field] == |
|
|
|
data_value_to_list[num_iter][field]).all() |
|
|
|
else: |
|
|
|
assert item[field] == data_value_to_list[num_iter][field] |
|
|
|
num_iter += 1 |
|
|
|
assert num_iter == 5 |
|
|
|
|
|
|
|
num_readers = 2 |
|
|
|
data_set = ds.MindDataset(dataset_file=mindrecord_file_name, |
|
|
|
columns_list=["float64", "int64"], |
|
|
|
num_parallel_workers=num_readers, |
|
|
|
shuffle=False) |
|
|
|
assert data_set.get_dataset_size() == 5 |
|
|
|
num_iter = 0 |
|
|
|
for item in data_set.create_dict_iterator(): |
|
|
|
assert len(item) == 2 |
|
|
|
for field in item: |
|
|
|
if isinstance(item[field], np.ndarray): |
|
|
|
if item[field].dtype == np.float32: |
|
|
|
assert (item[field] == |
|
|
|
np.array(data_value_to_list[num_iter][field], np.float32)).all() |
|
|
|
elif item[field].dtype == np.float64: |
|
|
|
assert math.isclose(item[field], |
|
|
|
np.array(data_value_to_list[num_iter][field], np.float64), rel_tol=1e-14) |
|
|
|
else: |
|
|
|
assert (item[field] == |
|
|
|
data_value_to_list[num_iter][field]).all() |
|
|
|
else: |
|
|
|
assert item[field] == data_value_to_list[num_iter][field] |
|
|
|
num_iter += 1 |
|
|
|
assert num_iter == 5 |
|
|
|
|
|
|
|
os.remove("{}".format(mindrecord_file_name)) |
|
|
|
os.remove("{}.db".format(mindrecord_file_name)) |