|
|
@@ -30,7 +30,7 @@ except ModuleNotFoundError: |
|
|
|
|
|
|
|
|
__all__ = ['TFRecordToMR'] |
|
|
__all__ = ['TFRecordToMR'] |
|
|
|
|
|
|
|
|
SupportedTensorFlowVersion = '2.1.0' |
|
|
|
|
|
|
|
|
SupportedTensorFlowVersion = '1.13.0-rc1' |
|
|
|
|
|
|
|
|
def _cast_type(value): |
|
|
def _cast_type(value): |
|
|
""" |
|
|
""" |
|
|
@@ -210,30 +210,84 @@ class TFRecordToMR: |
|
|
else: |
|
|
else: |
|
|
ms_dict[cast_key] = float(val.numpy()) |
|
|
ms_dict[cast_key] = float(val.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
def _get_data_when_scalar_field_oldversion(self, ms_dict, cast_key, key, val):
    """
    Put the value of one scalar field into ms_dict.

    Counterpart of the scalar-field handling used by the old-version
    iterator: `val` is used directly (no `.numpy()` call is made on it).

    Args:
        ms_dict (dict): Record being assembled; updated in place.
        cast_key (str): Key under which the value is stored in ms_dict.
        key (str): Key used to look the field up in self.feature_dict.
        val: Value of the field for the current record.

    Raises:
        ValueError: If val is a ndarray or a list instead of a scalar.
    """
    # Validate before branching on the type. This check used to live inside
    # the bytes/str branch, where it could never be true.
    # NOTE(review): assumes fields in scalar_set never legitimately arrive as
    # one-element arrays -- confirm against how scalar_set is built.
    if isinstance(val, (np.ndarray, list)):
        raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))

    if isinstance(val, (bytes, str)):
        if self.feature_dict[key].dtype == tf.string:
            if cast_key in self.bytes_fields_list or isinstance(val, str):
                # Bytes fields stay raw; a str is already text and has no
                # decode() method, so it is stored unchanged as well.
                ms_dict[cast_key] = val
            else:
                ms_dict[cast_key] = val.decode("utf-8")
        else:
            ms_dict[cast_key] = val
    elif _cast_type(self.feature_dict[key].dtype).startswith("int"):
        ms_dict[cast_key] = int(val)
    else:
        ms_dict[cast_key] = float(val)
|
|
|
|
|
|
|
|
|
|
|
def tfrecord_iterator_oldversion(self):
    """
    Yield a dict with key to be fields in schema, and value to be data.

    This is the session-based variant; run() selects it when
    tf.__version__ compares below '2.0.0'.

    Yields:
        dict, one record keyed by the cast field names.

    Raises:
        ValueError: If a non-scalar field is neither a ndarray nor a list,
            or if TensorFlow reports an InvalidArgumentError while
            evaluating a record.
    """
    dataset = tf.data.TFRecordDataset(self.source)
    dataset = dataset.map(self._parse_record)
    iterator = dataset.make_one_shot_iterator()
    # Create the get_next op exactly once. Calling get_next() inside the
    # loop would add a new op to the graph for every record read.
    next_sample = iterator.get_next()
    with tf.Session() as sess:
        while True:
            # Only sess.run() is guarded: it is the call that evaluates
            # the next record and signals the end of the dataset.
            try:
                sample = sess.run(next_sample)
            except tf.errors.OutOfRangeError:
                break
            except tf.errors.InvalidArgumentError as err:
                raise ValueError("TFRecord feature_dict parameter error.") from err

            ms_dict = {}
            for key, val in sample.items():
                cast_key = _cast_name(key)
                if cast_key in self.scalar_set:
                    self._get_data_when_scalar_field_oldversion(ms_dict, cast_key, key, val)
                else:
                    if not isinstance(val, (np.ndarray, list)):
                        raise ValueError("The response key: {}, value: {} from "
                                         "TFRecord should be a ndarray or "
                                         "list.".format(key, val))
                    # list set
                    ms_dict[cast_key] = \
                        np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
            yield ms_dict
|
|
|
|
|
|
|
|
def tfrecord_iterator(self):
    """
    Yield a dict with key to be fields in schema, and value to be data.

    Records are pulled one at a time with get_next(); iteration stops when
    TensorFlow raises OutOfRangeError.

    Yields:
        dict, one record keyed by the cast field names.

    Raises:
        ValueError: If a non-scalar field is neither a ndarray nor a list,
            or if TensorFlow reports an InvalidArgumentError.
    """
    dataset = tf.data.TFRecordDataset(self.source)
    dataset = dataset.map(self._parse_record)
    iterator = iter(dataset)
    while True:
        try:
            ms_dict = {}
            sample = iterator.get_next()
            for key, val in sample.items():
                cast_key = _cast_name(key)
                if cast_key in self.scalar_set:
                    self._get_data_when_scalar_field(ms_dict, cast_key, key, val)
                else:
                    if not isinstance(val.numpy(), (np.ndarray, list)):
                        raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or "
                                         "list.".format(key, val))
                    # list set
                    ms_dict[cast_key] = \
                        np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"]))
            yield ms_dict
        except tf.errors.OutOfRangeError:
            # get_next() signals the end of the dataset this way.
            break
        except tf.errors.InvalidArgumentError as err:
            raise ValueError("TFRecord feature_dict parameter error.") from err
|
|
|
|
|
|
|
|
def run(self): |
|
|
def run(self): |
|
|
""" |
|
|
""" |
|
|
@@ -247,10 +301,11 @@ class TFRecordToMR: |
|
|
.format(self.mindrecord_schema, self.feature_dict)) |
|
|
.format(self.mindrecord_schema, self.feature_dict)) |
|
|
|
|
|
|
|
|
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") |
|
|
writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") |
|
|
|
|
|
|
|
|
tf_iter = self.tfrecord_iterator() |
|
|
|
|
|
|
|
|
if tf.__version__ < '2.0.0': |
|
|
|
|
|
tf_iter = self.tfrecord_iterator_oldversion() |
|
|
|
|
|
else: |
|
|
|
|
|
tf_iter = self.tfrecord_iterator() |
|
|
batch_size = 256 |
|
|
batch_size = 256 |
|
|
|
|
|
|
|
|
transform_count = 0 |
|
|
transform_count = 0 |
|
|
while True: |
|
|
while True: |
|
|
data_list = [] |
|
|
data_list = [] |
|
|
|