|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
import collections |
|
|
|
from importlib import import_module |
|
|
|
import os |
|
|
|
from string import punctuation |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
@@ -35,6 +36,27 @@ TFRECORD_FILE_NAME = "test.tfrecord" |
|
|
|
MINDRECORD_FILE_NAME = "test.mindrecord" |
|
|
|
PARTITION_NUM = 1 |
|
|
|
|
|
|
|
def cast_name(key): |
|
|
|
""" |
|
|
|
Cast schema names which containing special characters to valid names. |
|
|
|
|
|
|
|
Here special characters means any characters in |
|
|
|
'!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~ |
|
|
|
Valid names can only contain a-z, A-Z, and 0-9 and _ |
|
|
|
|
|
|
|
Args: |
|
|
|
key (str): original key that might contains special characters. |
|
|
|
|
|
|
|
Returns: |
|
|
|
str, casted key that replace the special characters with "_". i.e. if |
|
|
|
key is "a b" then returns "a_b". |
|
|
|
""" |
|
|
|
special_symbols = set('{}{}'.format(punctuation, ' ')) |
|
|
|
special_symbols.remove('_') |
|
|
|
new_key = ['_' if x in special_symbols else x for x in key] |
|
|
|
casted_key = ''.join(new_key) |
|
|
|
return casted_key |
|
|
|
|
|
|
|
def verify_data(transformer, reader): |
|
|
|
"""Verify the data by read from mindrecord""" |
|
|
|
tf_iter = transformer.tfrecord_iterator() |
|
|
|
@@ -43,14 +65,14 @@ def verify_data(transformer, reader): |
|
|
|
count = 0 |
|
|
|
for tf_item, mr_item in zip(tf_iter, mr_iter): |
|
|
|
count = count + 1 |
|
|
|
assert len(tf_item) == 6 |
|
|
|
assert len(mr_item) == 6 |
|
|
|
assert len(tf_item) == len(mr_item) |
|
|
|
for key, value in tf_item.items(): |
|
|
|
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, mr_item[key])) |
|
|
|
logger.info("key: {}, tfrecord: value: {}, mindrecord: value: {}".format(key, value, |
|
|
|
mr_item[cast_name(key)])) |
|
|
|
if isinstance(value, np.ndarray): |
|
|
|
assert (value == mr_item[key]).all() |
|
|
|
assert (value == mr_item[cast_name(key)]).all() |
|
|
|
else: |
|
|
|
assert value == mr_item[key] |
|
|
|
assert value == mr_item[cast_name(key)] |
|
|
|
assert count == 10 |
|
|
|
|
|
|
|
def generate_tfrecord(): |
|
|
|
@@ -102,6 +124,39 @@ def generate_tfrecord(): |
|
|
|
writer.close() |
|
|
|
logger.info("Write {} rows in tfrecord.".format(example_count)) |
|
|
|
|
|
|
|
def generate_tfrecord_with_special_field_name(): |
|
|
|
def create_int_feature(values): |
|
|
|
if isinstance(values, list): |
|
|
|
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) # values: [int, int, int] |
|
|
|
else: |
|
|
|
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[values])) # values: int |
|
|
|
return feature |
|
|
|
|
|
|
|
def create_bytes_feature(values): |
|
|
|
if isinstance(values, bytes): |
|
|
|
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) # values: bytes |
|
|
|
else: |
|
|
|
# values: string |
|
|
|
feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[bytes(values, encoding='utf-8')])) |
|
|
|
return feature |
|
|
|
|
|
|
|
writer = tf.io.TFRecordWriter(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |
|
|
|
|
|
|
|
example_count = 0 |
|
|
|
for i in range(10): |
|
|
|
label = i |
|
|
|
image_bytes = bytes(str("aaaabbbbcccc" + str(i)), encoding="utf-8") |
|
|
|
|
|
|
|
features = collections.OrderedDict() |
|
|
|
features["image/class/label"] = create_int_feature(label) |
|
|
|
features["image/encoded"] = create_bytes_feature(image_bytes) |
|
|
|
|
|
|
|
tf_example = tf.train.Example(features=tf.train.Features(feature=features)) |
|
|
|
writer.write(tf_example.SerializeToString()) |
|
|
|
example_count += 1 |
|
|
|
writer.close() |
|
|
|
logger.info("Write {} rows in tfrecord.".format(example_count)) |
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord(): |
|
|
|
"""test transform tfrecord to mindrecord.""" |
|
|
|
if not tf or tf.__version__ < SupportedTensorFlowVersion: |
|
|
|
@@ -398,3 +453,110 @@ def test_tfrecord_to_mindrecord_scalar_bytes_with_10_exception(): |
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_string_type(): |
|
|
|
"""test transform tfrecord to mindrecord.""" |
|
|
|
if not tf or tf.__version__ < SupportedTensorFlowVersion: |
|
|
|
# skip the test |
|
|
|
logger.warning("Module tensorflow is not found or version wrong, \ |
|
|
|
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) |
|
|
|
return |
|
|
|
|
|
|
|
generate_tfrecord() |
|
|
|
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |
|
|
|
|
|
|
|
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), |
|
|
|
"image_bytes": tf.io.FixedLenFeature([], tf.string), |
|
|
|
"int64_scalar": tf.io.FixedLenFeature([], tf.int64), |
|
|
|
"float_scalar": tf.io.FixedLenFeature([], tf.float32), |
|
|
|
"int64_list": tf.io.FixedLenFeature([6], tf.int64), |
|
|
|
"float_list": tf.io.FixedLenFeature([7], tf.float32), |
|
|
|
} |
|
|
|
|
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME): |
|
|
|
os.remove(MINDRECORD_FILE_NAME) |
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME + ".db"): |
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
with pytest.raises(ValueError): |
|
|
|
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), |
|
|
|
MINDRECORD_FILE_NAME, feature_dict, ["int64_list"]) |
|
|
|
tfrecord_transformer.transform() |
|
|
|
|
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME): |
|
|
|
os.remove(MINDRECORD_FILE_NAME) |
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME + ".db"): |
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord_exception_bytes_fields_is_not_list(): |
|
|
|
"""test transform tfrecord to mindrecord.""" |
|
|
|
if not tf or tf.__version__ < SupportedTensorFlowVersion: |
|
|
|
# skip the test |
|
|
|
logger.warning("Module tensorflow is not found or version wrong, \ |
|
|
|
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) |
|
|
|
return |
|
|
|
|
|
|
|
generate_tfrecord() |
|
|
|
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |
|
|
|
|
|
|
|
feature_dict = {"file_name": tf.io.FixedLenFeature([], tf.string), |
|
|
|
"image_bytes": tf.io.FixedLenFeature([], tf.string), |
|
|
|
"int64_scalar": tf.io.FixedLenFeature([], tf.int64), |
|
|
|
"float_scalar": tf.io.FixedLenFeature([], tf.float32), |
|
|
|
"int64_list": tf.io.FixedLenFeature([6], tf.int64), |
|
|
|
"float_list": tf.io.FixedLenFeature([7], tf.float32), |
|
|
|
} |
|
|
|
|
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME): |
|
|
|
os.remove(MINDRECORD_FILE_NAME) |
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME + ".db"): |
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
with pytest.raises(ValueError): |
|
|
|
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), |
|
|
|
MINDRECORD_FILE_NAME, feature_dict, "") |
|
|
|
tfrecord_transformer.transform() |
|
|
|
|
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME): |
|
|
|
os.remove(MINDRECORD_FILE_NAME) |
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME + ".db"): |
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |
|
|
|
|
|
|
|
def test_tfrecord_to_mindrecord_with_special_field_name(): |
|
|
|
"""test transform tfrecord to mindrecord.""" |
|
|
|
if not tf or tf.__version__ < SupportedTensorFlowVersion: |
|
|
|
# skip the test |
|
|
|
logger.warning("Module tensorflow is not found or version wrong, \ |
|
|
|
please use pip install it / reinstall version >= {}.".format(SupportedTensorFlowVersion)) |
|
|
|
return |
|
|
|
|
|
|
|
generate_tfrecord_with_special_field_name() |
|
|
|
assert os.path.exists(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |
|
|
|
|
|
|
|
feature_dict = {"image/class/label": tf.io.FixedLenFeature([], tf.int64), |
|
|
|
"image/encoded": tf.io.FixedLenFeature([], tf.string), |
|
|
|
} |
|
|
|
|
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME): |
|
|
|
os.remove(MINDRECORD_FILE_NAME) |
|
|
|
if os.path.exists(MINDRECORD_FILE_NAME + ".db"): |
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
tfrecord_transformer = TFRecordToMR(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME), |
|
|
|
MINDRECORD_FILE_NAME, feature_dict, ["image/encoded"]) |
|
|
|
tfrecord_transformer.transform() |
|
|
|
|
|
|
|
assert os.path.exists(MINDRECORD_FILE_NAME) |
|
|
|
assert os.path.exists(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
fr_mindrecord = FileReader(MINDRECORD_FILE_NAME) |
|
|
|
verify_data(tfrecord_transformer, fr_mindrecord) |
|
|
|
|
|
|
|
os.remove(MINDRECORD_FILE_NAME) |
|
|
|
os.remove(MINDRECORD_FILE_NAME + ".db") |
|
|
|
|
|
|
|
os.remove(os.path.join(TFRECORD_DATA_DIR, TFRECORD_FILE_NAME)) |