|
|
|
@@ -105,21 +105,21 @@ class TFRecordToMR: |
|
|
|
source (str): the TFRecord file to be transformed. |
|
|
|
destination (str): the MindRecord file path to tranform into. |
|
|
|
feature_dict (dict): a dictionary than states the feature type, i.e. |
|
|
|
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), |
|
|
|
feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \ |
|
|
|
"yyyy": tf.io.FixedLenFeature([], tf.int64)} |
|
|
|
****** follow case which uses VarLenFeature not support ****** |
|
|
|
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), |
|
|
|
"yyyy": tf.io.VarLenFeature(tf.int64)}, |
|
|
|
|
|
|
|
**Follow case which uses VarLenFeature not support** |
|
|
|
|
|
|
|
feature_dict = {"context": {"xxxx": tf.io.FixedLenFeature([], tf.string), \ |
|
|
|
"yyyy": tf.io.VarLenFeature(tf.int64)}, \ |
|
|
|
"sequence": {"zzzz": tf.io.FixedLenSequenceFeature([], tf.float32)}} |
|
|
|
bytes_fields (list): the bytes fields which are in feature_dict. |
|
|
|
|
|
|
|
Rasies: |
|
|
|
ValueError, when: |
|
|
|
1) parameter TFRecord is not string. |
|
|
|
2) parameter MindRecord is not string. |
|
|
|
3) feature_dict is not FixedLenFeature. |
|
|
|
4) parameter bytes_field is not list(str) or not in feature_dict |
|
|
|
Exception, when tensorflow module not found or version is not correct. |
|
|
|
ValueError: the following condition will cause ValueError, 1) parameter TFRecord is not string, 2) parameter |
|
|
|
MindRecord is not string, 3) feature_dict is not FixedLenFeature, 4) parameter bytes_field is not list(str) |
|
|
|
or not in feature_dict. |
|
|
|
Exception: when tensorflow module not found or version is not correct. |
|
|
|
""" |
|
|
|
def __init__(self, source, destination, feature_dict, bytes_fields=None): |
|
|
|
if not tf: |
|
|
|
|