| @@ -23,46 +23,11 @@ from mindspore import log as logger | |||
| from ..filewriter import FileWriter | |||
| from ..shardutils import check_filename, ExceptionThread | |||
| try: | |||
| tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||
| except ModuleNotFoundError: | |||
| tf = None | |||
| __all__ = ['TFRecordToMR'] | |||
| SupportedTensorFlowVersion = '1.13.0-rc1' | |||
def _cast_type(value):
    """
    Cast a TFRecord data type to the basic datatype MindRecord recognizes.

    Args:
        value: the TFRecord data type; it is looked up among the dtype
            attributes of the module-level ``tf`` listed in the table below.

    Returns:
        str, which is MindRecord field type.

    Raises:
        ValueError: if ``value`` is not a key of the table below (for example
            ``tf.complex64`` or ``tf.complex128``, which have no entry).
    """
    # Narrow integer/float types are widened to the closest MindRecord type;
    # bool is stored as int32.
    tf_type_to_mr_type = {
        tf.string: "string",
        tf.int8: "int32",
        tf.int16: "int32",
        tf.int32: "int32",
        tf.int64: "int64",
        tf.uint8: "int32",
        tf.uint16: "int32",
        tf.uint32: "int64",
        tf.uint64: "int64",
        tf.float16: "float32",
        tf.float32: "float32",
        tf.float64: "float64",
        tf.double: "float64",
        tf.bool: "int32",
    }

    if value in tf_type_to_mr_type:
        return tf_type_to_mr_type[value]

    # Format the value instead of concatenating it: callers pass a dtype
    # object, and "Type " + <non-str> raises TypeError, which would mask the
    # ValueError this function is meant to raise.
    raise ValueError("Type {} is not supported in MindRecord.".format(value))
| def _cast_string_type_to_np_type(value): | |||
| """Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64""" | |||
| @@ -76,6 +41,7 @@ def _cast_string_type_to_np_type(value): | |||
| raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.") | |||
| def _cast_name(key): | |||
| """ | |||
| Cast schema names which containing special characters to valid names. | |||
| @@ -97,6 +63,7 @@ def _cast_name(key): | |||
| casted_key = ''.join(new_key) | |||
| return casted_key | |||
| class TFRecordToMR: | |||
| """ | |||
| A class to transform from TFRecord to MindRecord. | |||
| @@ -120,10 +87,14 @@ class TFRecordToMR: | |||
| Exception: when tensorflow module is not found or version is not correct. | |||
| """ | |||
| def __init__(self, source, destination, feature_dict, bytes_fields=None): | |||
| if not tf: | |||
| try: | |||
| self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||
| except ModuleNotFoundError: | |||
| self.tf = None | |||
| if not self.tf: | |||
| raise Exception("Module tensorflow is not found, please use pip install it.") | |||
| if tf.__version__ < SupportedTensorFlowVersion: | |||
| if self.tf.__version__ < SupportedTensorFlowVersion: | |||
| raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion)) | |||
| if not isinstance(source, str): | |||
| @@ -141,7 +112,7 @@ class TFRecordToMR: | |||
| raise ValueError("Parameter feature_dict is None or not dict.") | |||
| for key, val in feature_dict.items(): | |||
| if not isinstance(val, tf.io.FixedLenFeature): | |||
| if not isinstance(val, self.tf.io.FixedLenFeature): | |||
| raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) | |||
| self.feature_dict = feature_dict | |||
| @@ -161,7 +132,7 @@ class TFRecordToMR: | |||
| if not isinstance(self.feature_dict[item].shape, list): | |||
| raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item)) | |||
| if self.feature_dict[item].dtype != tf.string: | |||
| if self.feature_dict[item].dtype != self.tf.string: | |||
| raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item)) | |||
| casted_bytes_field = _cast_name(item) | |||
| @@ -178,34 +149,34 @@ class TFRecordToMR: | |||
| if _cast_name(key) in self.bytes_fields_list: | |||
| mindrecord_schema[_cast_name(key)] = {"type": "bytes"} | |||
| else: | |||
| mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)} | |||
| mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype)} | |||
| else: | |||
| if len(val.shape) != 1: | |||
| raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.") | |||
| if val.shape[0] < 1: | |||
| raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key)) | |||
| if val.dtype == tf.string: | |||
| raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] " \ | |||
| "is not None. It is not supported.".format(key)) | |||
| if val.dtype == self.tf.string: | |||
| raise ValueError("Parameter feature_dict[{}].dtype is tf.string which shape[0] " | |||
| "is not None. It is not supported.".format(key)) | |||
| self.list_set.add(_cast_name(key)) | |||
| mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]} | |||
| mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]} | |||
| self.mindrecord_schema = mindrecord_schema | |||
| def _parse_record(self, example): | |||
| """Returns features for a single example""" | |||
| features = tf.io.parse_single_example(example, features=self.feature_dict) | |||
| features = self.tf.io.parse_single_example(example, features=self.feature_dict) | |||
| return features | |||
| def _get_data_when_scalar_field(self, ms_dict, cast_key, key, val): | |||
| """put data in ms_dict when field type is string""" | |||
| if isinstance(val.numpy(), (np.ndarray, list)): | |||
| raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) | |||
| if self.feature_dict[key].dtype == tf.string: | |||
| if self.feature_dict[key].dtype == self.tf.string: | |||
| if cast_key in self.bytes_fields_list: | |||
| ms_dict[cast_key] = val.numpy() | |||
| else: | |||
| ms_dict[cast_key] = str(val.numpy(), encoding="utf-8") | |||
| elif _cast_type(self.feature_dict[key].dtype).startswith("int"): | |||
| elif self._cast_type(self.feature_dict[key].dtype).startswith("int"): | |||
| ms_dict[cast_key] = int(val.numpy()) | |||
| else: | |||
| ms_dict[cast_key] = float(val.numpy()) | |||
| @@ -218,7 +189,7 @@ class TFRecordToMR: | |||
| if isinstance(val, (bytes, str)): | |||
| if isinstance(val, (np.ndarray, list)): | |||
| raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) | |||
| if self.feature_dict[key].dtype == tf.string: | |||
| if self.feature_dict[key].dtype == self.tf.string: | |||
| if cast_key in self.bytes_fields_list: | |||
| ms_dict[cast_key] = val | |||
| else: | |||
| @@ -226,7 +197,7 @@ class TFRecordToMR: | |||
| else: | |||
| ms_dict[cast_key] = val | |||
| else: | |||
| if _cast_type(self.feature_dict[key].dtype).startswith("int"): | |||
| if self._cast_type(self.feature_dict[key].dtype).startswith("int"): | |||
| ms_dict[cast_key] = int(val) | |||
| else: | |||
| ms_dict[cast_key] = float(val) | |||
| @@ -236,10 +207,10 @@ class TFRecordToMR: | |||
| Yield a dict with key to be fields in schema, and value to be data. | |||
| This function is for old version tensorflow whose version number < 2.1.0 | |||
| """ | |||
| dataset = tf.data.TFRecordDataset(self.source) | |||
| dataset = self.tf.data.TFRecordDataset(self.source) | |||
| dataset = dataset.map(self._parse_record) | |||
| iterator = dataset.make_one_shot_iterator() | |||
| with tf.Session() as sess: | |||
| with self.tf.Session() as sess: | |||
| while True: | |||
| try: | |||
| ms_dict = {} | |||
| @@ -258,14 +229,14 @@ class TFRecordToMR: | |||
| ms_dict[cast_key] = \ | |||
| np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) | |||
| yield ms_dict | |||
| except tf.errors.OutOfRangeError: | |||
| except self.tf.errors.OutOfRangeError: | |||
| break | |||
| except tf.errors.InvalidArgumentError: | |||
| except self.tf.errors.InvalidArgumentError: | |||
| raise ValueError("TFRecord feature_dict parameter error.") | |||
| def tfrecord_iterator(self): | |||
| """Yield a dictionary whose keys are fields in schema.""" | |||
| dataset = tf.data.TFRecordDataset(self.source) | |||
| dataset = self.tf.data.TFRecordDataset(self.source) | |||
| dataset = dataset.map(self._parse_record) | |||
| iterator = dataset.__iter__() | |||
| while True: | |||
| @@ -278,15 +249,15 @@ class TFRecordToMR: | |||
| self._get_data_when_scalar_field(ms_dict, cast_key, key, val) | |||
| else: | |||
| if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list): | |||
| raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \ | |||
| raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " | |||
| "list.".format(key, val)) | |||
| # list set | |||
| ms_dict[cast_key] = \ | |||
| np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) | |||
| yield ms_dict | |||
| except tf.errors.OutOfRangeError: | |||
| except self.tf.errors.OutOfRangeError: | |||
| break | |||
| except tf.errors.InvalidArgumentError: | |||
| except self.tf.errors.InvalidArgumentError: | |||
| raise ValueError("TFRecord feature_dict parameter error.") | |||
| def run(self): | |||
| @@ -301,7 +272,7 @@ class TFRecordToMR: | |||
| .format(self.mindrecord_schema, self.feature_dict)) | |||
| writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") | |||
| if tf.__version__ < '2.0.0': | |||
| if self.tf.__version__ < '2.0.0': | |||
| tf_iter = self.tfrecord_iterator_oldversion() | |||
| else: | |||
| tf_iter = self.tfrecord_iterator() | |||
| @@ -331,3 +302,35 @@ class TFRecordToMR: | |||
| if t.exitcode != 0: | |||
| raise t.exception | |||
| return t.res | |||
| def _cast_type(self, value): | |||
| """ | |||
| Cast complex data type to basic datatype for MindRecord to recognize. | |||
| Args: | |||
| value: the TFRecord data type | |||
| Returns: | |||
| str, which is MindRecord field type. | |||
| """ | |||
| tf_type_to_mr_type = {self.tf.string: "string", | |||
| self.tf.int8: "int32", | |||
| self.tf.int16: "int32", | |||
| self.tf.int32: "int32", | |||
| self.tf.int64: "int64", | |||
| self.tf.uint8: "int32", | |||
| self.tf.uint16: "int32", | |||
| self.tf.uint32: "int64", | |||
| self.tf.uint64: "int64", | |||
| self.tf.float16: "float32", | |||
| self.tf.float32: "float32", | |||
| self.tf.float64: "float64", | |||
| self.tf.double: "float64", | |||
| self.tf.bool: "int32"} | |||
| unsupport_tf_type_to_mr_type = {self.tf.complex64: "None", | |||
| self.tf.complex128: "None"} | |||
| if value in tf_type_to_mr_type: | |||
| return tf_type_to_mr_type[value] | |||
| raise ValueError("Type " + value + " is not supported in MindRecord.") | |||