|
|
|
@@ -23,46 +23,11 @@ from mindspore import log as logger |
|
|
|
from ..filewriter import FileWriter |
|
|
|
from ..shardutils import check_filename, ExceptionThread |
|
|
|
|
|
|
|
try: |
|
|
|
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord |
|
|
|
except ModuleNotFoundError: |
|
|
|
tf = None |
|
|
|
|
|
|
|
__all__ = ['TFRecordToMR'] |
|
|
|
|
|
|
|
SupportedTensorFlowVersion = '1.13.0-rc1' |
|
|
|
|
|
|
|
def _cast_type(value):
    """
    Cast a TFRecord data type to the basic data type MindRecord recognizes.

    Narrow integer types are widened to "int32"/"int64", float16 is widened
    to "float32" and bool is stored as "int32".

    Args:
        value: the TFRecord data type, looked up among the dtypes of the
            module-level `tf`.

    Returns:
        str, which is MindRecord field type.

    Raises:
        ValueError: If `value` has no entry in the mapping (complex64 and
            complex128, for instance, are not mapped).
    """
    tf_type_to_mr_type = {tf.string: "string",
                          tf.int8: "int32",
                          tf.int16: "int32",
                          tf.int32: "int32",
                          tf.int64: "int64",
                          tf.uint8: "int32",
                          tf.uint16: "int32",
                          tf.uint32: "int64",
                          tf.uint64: "int64",
                          tf.float16: "float32",
                          tf.float32: "float32",
                          tf.float64: "float64",
                          tf.double: "float64",
                          tf.bool: "int32"}

    if value in tf_type_to_mr_type:
        return tf_type_to_mr_type[value]

    # `value` is not guaranteed to be a str; "Type " + value would raise
    # TypeError for any other object and hide the intended ValueError, so the
    # value is formatted into the message instead.
    raise ValueError("Type {} is not supported in MindRecord.".format(value))
|
|
|
|
|
|
|
def _cast_string_type_to_np_type(value): |
|
|
|
"""Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64""" |
|
|
|
@@ -76,6 +41,7 @@ def _cast_string_type_to_np_type(value): |
|
|
|
|
|
|
|
raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.") |
|
|
|
|
|
|
|
|
|
|
|
def _cast_name(key): |
|
|
|
""" |
|
|
|
Cast schema names which containing special characters to valid names. |
|
|
|
@@ -97,6 +63,7 @@ def _cast_name(key): |
|
|
|
casted_key = ''.join(new_key) |
|
|
|
return casted_key |
|
|
|
|
|
|
|
|
|
|
|
class TFRecordToMR: |
|
|
|
""" |
|
|
|
A class to transform from TFRecord to MindRecord. |
|
|
|
@@ -120,10 +87,14 @@ class TFRecordToMR: |
|
|
|
Exception: when tensorflow module is not found or version is not correct. |
|
|
|
""" |
|
|
|
def __init__(self, source, destination, feature_dict, bytes_fields=None): |
|
|
|
if not tf: |
|
|
|
try: |
|
|
|
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord |
|
|
|
except ModuleNotFoundError: |
|
|
|
self.tf = None |
|
|
|
if not self.tf: |
|
|
|
raise Exception("Module tensorflow is not found, please use pip install it.") |
|
|
|
|
|
|
|
if tf.__version__ < SupportedTensorFlowVersion: |
|
|
|
if self.tf.__version__ < SupportedTensorFlowVersion: |
|
|
|
raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion)) |
|
|
|
|
|
|
|
if not isinstance(source, str): |
|
|
|
@@ -141,7 +112,7 @@ class TFRecordToMR: |
|
|
|
raise ValueError("Parameter feature_dict is None or not dict.") |
|
|
|
|
|
|
|
for key, val in feature_dict.items(): |
|
|
|
if not isinstance(val, tf.io.FixedLenFeature): |
|
|
|
if not isinstance(val, self.tf.io.FixedLenFeature): |
|
|
|
raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) |
|
|
|
|
|
|
|
self.feature_dict = feature_dict |
|
|
|
@@ -161,7 +132,7 @@ class TFRecordToMR: |
|
|
|
if not isinstance(self.feature_dict[item].shape, list): |
|
|
|
raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item)) |
|
|
|
|
|
|
|
if self.feature_dict[item].dtype != tf.string: |
|
|
|
if self.feature_dict[item].dtype != self.tf.string: |
|
|
|
raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item)) |
|
|
|
|
|
|
|
casted_bytes_field = _cast_name(item) |
|
|
|
@@ -178,34 +149,34 @@ class TFRecordToMR: |
|
|
|
if _cast_name(key) in self.bytes_fields_list: |
|
|
|
mindrecord_schema[_cast_name(key)] = {"type": "bytes"} |
|
|
|
else: |
|
|
|
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)} |
|
|
|
mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype)} |
|
|
|
else: |
|
|
|
if len(val.shape) != 1: |
|
|
|
raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.") |
|
|
|
if val.shape[0] < 1: |
|
|
|
raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key)) |
|
|
|
if val.dtype == tf.string: |
|
|
|
raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] " \ |
|
|
|
"is not None. It is not supported.".format(key)) |
|
|
|
if val.dtype == self.tf.string: |
|
|
|
raise ValueError("Parameter feature_dict[{}].dtype is tf.string which shape[0] " |
|
|
|
"is not None. It is not supported.".format(key)) |
|
|
|
self.list_set.add(_cast_name(key)) |
|
|
|
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]} |
|
|
|
mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]} |
|
|
|
self.mindrecord_schema = mindrecord_schema |
|
|
|
|
|
|
|
def _parse_record(self, example):
    """
    Return the features parsed from a single example.

    Args:
        example: one record, handed unchanged to
            `self.tf.io.parse_single_example`.

    Returns:
        The result of `self.tf.io.parse_single_example` for `example`,
        using `self.feature_dict` as the `features` argument.
    """
    # Parse exactly once, through the tensorflow module stored on the
    # instance, so the method does not depend on a module-level `tf` name.
    return self.tf.io.parse_single_example(example, features=self.feature_dict)
|
|
|
|
|
|
|
def _get_data_when_scalar_field(self, ms_dict, cast_key, key, val): |
|
|
|
"""put data in ms_dict when field type is string""" |
|
|
|
if isinstance(val.numpy(), (np.ndarray, list)): |
|
|
|
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) |
|
|
|
if self.feature_dict[key].dtype == tf.string: |
|
|
|
if self.feature_dict[key].dtype == self.tf.string: |
|
|
|
if cast_key in self.bytes_fields_list: |
|
|
|
ms_dict[cast_key] = val.numpy() |
|
|
|
else: |
|
|
|
ms_dict[cast_key] = str(val.numpy(), encoding="utf-8") |
|
|
|
elif _cast_type(self.feature_dict[key].dtype).startswith("int"): |
|
|
|
elif self._cast_type(self.feature_dict[key].dtype).startswith("int"): |
|
|
|
ms_dict[cast_key] = int(val.numpy()) |
|
|
|
else: |
|
|
|
ms_dict[cast_key] = float(val.numpy()) |
|
|
|
@@ -218,7 +189,7 @@ class TFRecordToMR: |
|
|
|
if isinstance(val, (bytes, str)): |
|
|
|
if isinstance(val, (np.ndarray, list)): |
|
|
|
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) |
|
|
|
if self.feature_dict[key].dtype == tf.string: |
|
|
|
if self.feature_dict[key].dtype == self.tf.string: |
|
|
|
if cast_key in self.bytes_fields_list: |
|
|
|
ms_dict[cast_key] = val |
|
|
|
else: |
|
|
|
@@ -226,7 +197,7 @@ class TFRecordToMR: |
|
|
|
else: |
|
|
|
ms_dict[cast_key] = val |
|
|
|
else: |
|
|
|
if _cast_type(self.feature_dict[key].dtype).startswith("int"): |
|
|
|
if self._cast_type(self.feature_dict[key].dtype).startswith("int"): |
|
|
|
ms_dict[cast_key] = int(val) |
|
|
|
else: |
|
|
|
ms_dict[cast_key] = float(val) |
|
|
|
@@ -236,10 +207,10 @@ class TFRecordToMR: |
|
|
|
Yield a dict with key to be fields in schema, and value to be data. |
|
|
|
This function is for old version tensorflow whose version number < 2.1.0 |
|
|
|
""" |
|
|
|
dataset = tf.data.TFRecordDataset(self.source) |
|
|
|
dataset = self.tf.data.TFRecordDataset(self.source) |
|
|
|
dataset = dataset.map(self._parse_record) |
|
|
|
iterator = dataset.make_one_shot_iterator() |
|
|
|
with tf.Session() as sess: |
|
|
|
with self.tf.Session() as sess: |
|
|
|
while True: |
|
|
|
try: |
|
|
|
ms_dict = {} |
|
|
|
@@ -258,14 +229,14 @@ class TFRecordToMR: |
|
|
|
ms_dict[cast_key] = \ |
|
|
|
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) |
|
|
|
yield ms_dict |
|
|
|
except tf.errors.OutOfRangeError: |
|
|
|
except self.tf.errors.OutOfRangeError: |
|
|
|
break |
|
|
|
except tf.errors.InvalidArgumentError: |
|
|
|
except self.tf.errors.InvalidArgumentError: |
|
|
|
raise ValueError("TFRecord feature_dict parameter error.") |
|
|
|
|
|
|
|
def tfrecord_iterator(self): |
|
|
|
"""Yield a dictionary whose keys are fields in schema.""" |
|
|
|
dataset = tf.data.TFRecordDataset(self.source) |
|
|
|
dataset = self.tf.data.TFRecordDataset(self.source) |
|
|
|
dataset = dataset.map(self._parse_record) |
|
|
|
iterator = dataset.__iter__() |
|
|
|
while True: |
|
|
|
@@ -278,15 +249,15 @@ class TFRecordToMR: |
|
|
|
self._get_data_when_scalar_field(ms_dict, cast_key, key, val) |
|
|
|
else: |
|
|
|
if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list): |
|
|
|
raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \ |
|
|
|
raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " |
|
|
|
"list.".format(key, val)) |
|
|
|
# list set |
|
|
|
ms_dict[cast_key] = \ |
|
|
|
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) |
|
|
|
yield ms_dict |
|
|
|
except tf.errors.OutOfRangeError: |
|
|
|
except self.tf.errors.OutOfRangeError: |
|
|
|
break |
|
|
|
except tf.errors.InvalidArgumentError: |
|
|
|
except self.tf.errors.InvalidArgumentError: |
|
|
|
raise ValueError("TFRecord feature_dict parameter error.") |
|
|
|
|
|
|
|
def run(self): |
|
|
|
@@ -301,7 +272,7 @@ class TFRecordToMR: |
|
|
|
.format(self.mindrecord_schema, self.feature_dict)) |
|
|
|
|
|
|
|
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") |
|
|
|
if tf.__version__ < '2.0.0': |
|
|
|
if self.tf.__version__ < '2.0.0': |
|
|
|
tf_iter = self.tfrecord_iterator_oldversion() |
|
|
|
else: |
|
|
|
tf_iter = self.tfrecord_iterator() |
|
|
|
@@ -331,3 +302,35 @@ class TFRecordToMR: |
|
|
|
if t.exitcode != 0: |
|
|
|
raise t.exception |
|
|
|
return t.res |
|
|
|
|
|
|
|
def _cast_type(self, value):
    """
    Cast a TFRecord data type to the basic data type MindRecord recognizes.

    Narrow integer types are widened to "int32"/"int64", float16 is widened
    to "float32" and bool is stored as "int32".

    Args:
        value: the TFRecord data type, looked up among the dtypes of
            `self.tf`.

    Returns:
        str, which is MindRecord field type.

    Raises:
        ValueError: If `value` has no entry in the mapping (complex64 and
            complex128, for instance, are not mapped).
    """
    tf_type_to_mr_type = {self.tf.string: "string",
                          self.tf.int8: "int32",
                          self.tf.int16: "int32",
                          self.tf.int32: "int32",
                          self.tf.int64: "int64",
                          self.tf.uint8: "int32",
                          self.tf.uint16: "int32",
                          self.tf.uint32: "int64",
                          self.tf.uint64: "int64",
                          self.tf.float16: "float32",
                          self.tf.float32: "float32",
                          self.tf.float64: "float64",
                          self.tf.double: "float64",
                          self.tf.bool: "int32"}

    if value in tf_type_to_mr_type:
        return tf_type_to_mr_type[value]

    # `value` is not guaranteed to be a str; "Type " + value would raise
    # TypeError for any other object and hide the intended ValueError, so the
    # value is formatted into the message instead.
    raise ValueError("Type {} is not supported in MindRecord.".format(value))