Browse Source

move import tensorflow into class tfrecord_to_mr

tags/v1.1.0
ms_yan 5 years ago
parent
commit
cb19781672
1 changed files with 61 additions and 58 deletions
  1. +61
    -58
      mindspore/mindrecord/tools/tfrecord_to_mr.py

+ 61
- 58
mindspore/mindrecord/tools/tfrecord_to_mr.py View File

@@ -23,46 +23,11 @@ from mindspore import log as logger
from ..filewriter import FileWriter
from ..shardutils import check_filename, ExceptionThread

try:
tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
except ModuleNotFoundError:
tf = None

__all__ = ['TFRecordToMR']

SupportedTensorFlowVersion = '1.13.0-rc1'

def _cast_type(value):
"""
Cast complex data type to basic datatype for MindRecord to recognize.

Args:
value: the TFRecord data type

Returns:
str, which is MindRecord field type.
"""
tf_type_to_mr_type = {tf.string: "string",
tf.int8: "int32",
tf.int16: "int32",
tf.int32: "int32",
tf.int64: "int64",
tf.uint8: "int32",
tf.uint16: "int32",
tf.uint32: "int64",
tf.uint64: "int64",
tf.float16: "float32",
tf.float32: "float32",
tf.float64: "float64",
tf.double: "float64",
tf.bool: "int32"}
unsupport_tf_type_to_mr_type = {tf.complex64: "None",
tf.complex128: "None"}

if value in tf_type_to_mr_type:
return tf_type_to_mr_type[value]

raise ValueError("Type " + value + " is not supported in MindRecord.")

def _cast_string_type_to_np_type(value):
"""Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64"""
@@ -76,6 +41,7 @@ def _cast_string_type_to_np_type(value):

raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.")


def _cast_name(key):
"""
Cast schema names which containing special characters to valid names.
@@ -97,6 +63,7 @@ def _cast_name(key):
casted_key = ''.join(new_key)
return casted_key


class TFRecordToMR:
"""
A class to transform from TFRecord to MindRecord.
@@ -120,10 +87,14 @@ class TFRecordToMR:
Exception: when tensorflow module is not found or version is not correct.
"""
def __init__(self, source, destination, feature_dict, bytes_fields=None):
if not tf:
try:
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord
except ModuleNotFoundError:
self.tf = None
if not self.tf:
raise Exception("Module tensorflow is not found, please use pip install it.")

if tf.__version__ < SupportedTensorFlowVersion:
if self.tf.__version__ < SupportedTensorFlowVersion:
raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion))

if not isinstance(source, str):
@@ -141,7 +112,7 @@ class TFRecordToMR:
raise ValueError("Parameter feature_dict is None or not dict.")

for key, val in feature_dict.items():
if not isinstance(val, tf.io.FixedLenFeature):
if not isinstance(val, self.tf.io.FixedLenFeature):
raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))

self.feature_dict = feature_dict
@@ -161,7 +132,7 @@ class TFRecordToMR:
if not isinstance(self.feature_dict[item].shape, list):
raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item))

if self.feature_dict[item].dtype != tf.string:
if self.feature_dict[item].dtype != self.tf.string:
raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item))

casted_bytes_field = _cast_name(item)
@@ -178,34 +149,34 @@ class TFRecordToMR:
if _cast_name(key) in self.bytes_fields_list:
mindrecord_schema[_cast_name(key)] = {"type": "bytes"}
else:
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)}
mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype)}
else:
if len(val.shape) != 1:
raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.")
if val.shape[0] < 1:
raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key))
if val.dtype == tf.string:
raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] " \
"is not None. It is not supported.".format(key))
if val.dtype == self.tf.string:
raise ValueError("Parameter feature_dict[{}].dtype is tf.string which shape[0] "
"is not None. It is not supported.".format(key))
self.list_set.add(_cast_name(key))
mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]}
mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]}
self.mindrecord_schema = mindrecord_schema

def _parse_record(self, example):
"""Returns features for a single example"""
features = tf.io.parse_single_example(example, features=self.feature_dict)
features = self.tf.io.parse_single_example(example, features=self.feature_dict)
return features

def _get_data_when_scalar_field(self, ms_dict, cast_key, key, val):
"""put data in ms_dict when field type is string"""
if isinstance(val.numpy(), (np.ndarray, list)):
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
if self.feature_dict[key].dtype == tf.string:
if self.feature_dict[key].dtype == self.tf.string:
if cast_key in self.bytes_fields_list:
ms_dict[cast_key] = val.numpy()
else:
ms_dict[cast_key] = str(val.numpy(), encoding="utf-8")
elif _cast_type(self.feature_dict[key].dtype).startswith("int"):
elif self._cast_type(self.feature_dict[key].dtype).startswith("int"):
ms_dict[cast_key] = int(val.numpy())
else:
ms_dict[cast_key] = float(val.numpy())
@@ -218,7 +189,7 @@ class TFRecordToMR:
if isinstance(val, (bytes, str)):
if isinstance(val, (np.ndarray, list)):
raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))
if self.feature_dict[key].dtype == tf.string:
if self.feature_dict[key].dtype == self.tf.string:
if cast_key in self.bytes_fields_list:
ms_dict[cast_key] = val
else:
@@ -226,7 +197,7 @@ class TFRecordToMR:
else:
ms_dict[cast_key] = val
else:
if _cast_type(self.feature_dict[key].dtype).startswith("int"):
if self._cast_type(self.feature_dict[key].dtype).startswith("int"):
ms_dict[cast_key] = int(val)
else:
ms_dict[cast_key] = float(val)
@@ -236,10 +207,10 @@ class TFRecordToMR:
Yield a dict with key to be fields in schema, and value to be data.
This function is for old version tensorflow whose version number < 2.1.0
"""
dataset = tf.data.TFRecordDataset(self.source)
dataset = self.tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
with self.tf.Session() as sess:
while True:
try:
ms_dict = {}
@@ -258,14 +229,14 @@ class TFRecordToMR:
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
yield ms_dict
except tf.errors.OutOfRangeError:
except self.tf.errors.OutOfRangeError:
break
except tf.errors.InvalidArgumentError:
except self.tf.errors.InvalidArgumentError:
raise ValueError("TFRecord feature_dict parameter error.")

def tfrecord_iterator(self):
"""Yield a dictionary whose keys are fields in schema."""
dataset = tf.data.TFRecordDataset(self.source)
dataset = self.tf.data.TFRecordDataset(self.source)
dataset = dataset.map(self._parse_record)
iterator = dataset.__iter__()
while True:
@@ -278,15 +249,15 @@ class TFRecordToMR:
self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
else:
if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list):
raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \
raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or "
"list.".format(key, val))
# list set
ms_dict[cast_key] = \
np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
yield ms_dict
except tf.errors.OutOfRangeError:
except self.tf.errors.OutOfRangeError:
break
except tf.errors.InvalidArgumentError:
except self.tf.errors.InvalidArgumentError:
raise ValueError("TFRecord feature_dict parameter error.")

def run(self):
@@ -301,7 +272,7 @@ class TFRecordToMR:
.format(self.mindrecord_schema, self.feature_dict))

writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord")
if tf.__version__ < '2.0.0':
if self.tf.__version__ < '2.0.0':
tf_iter = self.tfrecord_iterator_oldversion()
else:
tf_iter = self.tfrecord_iterator()
@@ -331,3 +302,35 @@ class TFRecordToMR:
if t.exitcode != 0:
raise t.exception
return t.res

def _cast_type(self, value):
"""
Cast complex data type to basic datatype for MindRecord to recognize.

Args:
value: the TFRecord data type

Returns:
str, which is MindRecord field type.
"""
tf_type_to_mr_type = {self.tf.string: "string",
self.tf.int8: "int32",
self.tf.int16: "int32",
self.tf.int32: "int32",
self.tf.int64: "int64",
self.tf.uint8: "int32",
self.tf.uint16: "int32",
self.tf.uint32: "int64",
self.tf.uint64: "int64",
self.tf.float16: "float32",
self.tf.float32: "float32",
self.tf.float64: "float64",
self.tf.double: "float64",
self.tf.bool: "int32"}
unsupport_tf_type_to_mr_type = {self.tf.complex64: "None",
self.tf.complex128: "None"}

if value in tf_type_to_mr_type:
return tf_type_to_mr_type[value]

raise ValueError("Type " + value + " is not supported in MindRecord.")

Loading…
Cancel
Save