|
|
|
@@ -69,7 +69,7 @@ class TFRecordToMR: |
|
|
|
|
|
|
|
Args: |
|
|
|
source (str): the TFRecord file to be transformed. |
|
|
|
destination (str): the MindRecord file path to tranform into. |
|
|
|
destination (str): the MindRecord file path to transform into. |
|
|
|
feature_dict (dict): a dictionary that states the feature type, e.g. |
|
|
|
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \ |
|
|
|
"yyyy": tf.io.FixedLenFeature([], tf.int64)} |
|
|
|
@@ -90,31 +90,14 @@ class TFRecordToMR: |
|
|
|
try: |
|
|
|
self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord |
|
|
|
except ModuleNotFoundError: |
|
|
|
self.tf = None |
|
|
|
if not self.tf: |
|
|
|
raise Exception("Module tensorflow is not found, please use pip install it.") |
|
|
|
|
|
|
|
if self.tf.__version__ < SupportedTensorFlowVersion: |
|
|
|
raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion)) |
|
|
|
|
|
|
|
if not isinstance(source, str): |
|
|
|
raise ValueError("Parameter source must be string.") |
|
|
|
check_filename(source) |
|
|
|
|
|
|
|
if not isinstance(destination, str): |
|
|
|
raise ValueError("Parameter destination must be string.") |
|
|
|
check_filename(destination) |
|
|
|
|
|
|
|
self._check_input(source, destination, feature_dict) |
|
|
|
self.source = source |
|
|
|
self.destination = destination |
|
|
|
|
|
|
|
if feature_dict is None or not isinstance(feature_dict, dict): |
|
|
|
raise ValueError("Parameter feature_dict is None or not dict.") |
|
|
|
|
|
|
|
for key, val in feature_dict.items(): |
|
|
|
if not isinstance(val, self.tf.io.FixedLenFeature): |
|
|
|
raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) |
|
|
|
|
|
|
|
self.feature_dict = feature_dict |
|
|
|
|
|
|
|
bytes_fields_list = [] |
|
|
|
@@ -162,6 +145,23 @@ class TFRecordToMR: |
|
|
|
mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]} |
|
|
|
self.mindrecord_schema = mindrecord_schema |
|
|
|
|
|
|
|
def _check_input(self, source, destination, feature_dict):
    """Validate the arguments received by the init method.

    Checks run in a fixed order: source, then destination, then
    feature_dict. The first failing check raises.

    Args:
        source: must be a str; it is then handed to check_filename.
        destination: must be a str; it is then handed to check_filename.
        feature_dict: must be a dict whose values are all instances of
            self.tf.io.FixedLenFeature.

    Raises:
        ValueError: if source or destination is not a str, if feature_dict
            is not a dict (this includes None), or if any value in
            feature_dict is not a FixedLenFeature.
    """
    # Each path is type-checked first and only then passed to
    # check_filename (defined outside this block; presumably it raises on
    # an invalid file name -- confirm against its definition).
    path_checks = (
        (source, "Parameter source must be string."),
        (destination, "Parameter destination must be string."),
    )
    for path, type_error in path_checks:
        if not isinstance(path, str):
            raise ValueError(type_error)
        check_filename(path)

    # None is not a dict, so this single test also rejects a missing value.
    if not isinstance(feature_dict, dict):
        raise ValueError("Parameter feature_dict is None or not dict.")

    # Only the values matter here; the keys are not inspected.
    for feature in feature_dict.values():
        if isinstance(feature, self.tf.io.FixedLenFeature):
            continue
        raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict))
|
|
|
|
|
|
|
def _parse_record(self, example): |
|
|
|
"""Returns features for a single example""" |
|
|
|
features = self.tf.io.parse_single_example(example, features=self.feature_dict) |
|
|
|
@@ -206,6 +206,9 @@ class TFRecordToMR: |
|
|
|
""" |
|
|
|
Yield a dict with key to be fields in schema, and value to be data. |
|
|
|
This function is for old version tensorflow whose version number < 2.1.0 |
|
|
|
|
|
|
|
Yields: |
|
|
|
dict, data dictionary whose keys are the same as columns. |
|
|
|
""" |
|
|
|
dataset = self.tf.data.TFRecordDataset(self.source) |
|
|
|
dataset = dataset.map(self._parse_record) |
|
|
|
@@ -235,7 +238,12 @@ class TFRecordToMR: |
|
|
|
raise ValueError("TFRecord feature_dict parameter error.") |
|
|
|
|
|
|
|
def tfrecord_iterator(self): |
|
|
|
"""Yield a dictionary whose keys are fields in schema.""" |
|
|
|
""" |
|
|
|
Yield a dictionary whose keys are fields in schema. |
|
|
|
|
|
|
|
Yields: |
|
|
|
dict, data dictionary whose keys are the same as columns. |
|
|
|
""" |
|
|
|
dataset = self.tf.data.TFRecordDataset(self.source) |
|
|
|
dataset = dataset.map(self._parse_record) |
|
|
|
iterator = dataset.__iter__() |
|
|
|
@@ -265,7 +273,7 @@ class TFRecordToMR: |
|
|
|
Execute transformation from TFRecord to MindRecord. |
|
|
|
|
|
|
|
Returns: |
|
|
|
MSRStatus, whether TFRecord is successfuly transformed to MindRecord. |
|
|
|
MSRStatus, whether TFRecord is successfully transformed to MindRecord. |
|
|
|
""" |
|
|
|
writer = FileWriter(self.destination) |
|
|
|
logger.info("Transformed MindRecord schema is: {}, TFRecord feature dict is: {}" |
|
|
|
|