| @@ -23,46 +23,11 @@ from mindspore import log as logger | |||||
| from ..filewriter import FileWriter | from ..filewriter import FileWriter | ||||
| from ..shardutils import check_filename, ExceptionThread | from ..shardutils import check_filename, ExceptionThread | ||||
| try: | |||||
| tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||||
| except ModuleNotFoundError: | |||||
| tf = None | |||||
| __all__ = ['TFRecordToMR'] | __all__ = ['TFRecordToMR'] | ||||
| SupportedTensorFlowVersion = '1.13.0-rc1' | SupportedTensorFlowVersion = '1.13.0-rc1' | ||||
def _cast_type(value):
    """
    Cast a TFRecord data type to the basic field type MindRecord recognizes.

    Narrow integer types and bool are widened to "int32", uint32/uint64 are
    stored as "int64", and float16 is stored as "float32". Any type outside
    the table (for example the complex types) is rejected.

    Args:
        value: the TFRecord data type.

    Returns:
        str, which is MindRecord field type.

    Raises:
        ValueError: if the type has no MindRecord equivalent.
    """
    tf_type_to_mr_type = {tf.string: "string",
                          tf.int8: "int32",
                          tf.int16: "int32",
                          tf.int32: "int32",
                          tf.int64: "int64",
                          tf.uint8: "int32",
                          tf.uint16: "int32",
                          tf.uint32: "int64",
                          tf.uint64: "int64",
                          tf.float16: "float32",
                          tf.float32: "float32",
                          tf.float64: "float64",
                          tf.double: "float64",
                          tf.bool: "int32"}

    if value in tf_type_to_mr_type:
        return tf_type_to_mr_type[value]

    # Format the value instead of concatenating it: callers pass a dtype
    # object, and "str + object" would raise TypeError and mask this error.
    raise ValueError("Type {} is not supported in MindRecord.".format(value))
| def _cast_string_type_to_np_type(value): | def _cast_string_type_to_np_type(value): | ||||
| """Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64""" | """Cast string type like: int32/int64/float32/float64 to np.int32/np.int64/np.float32/np.float64""" | ||||
| @@ -76,6 +41,7 @@ def _cast_string_type_to_np_type(value): | |||||
| raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.") | raise ValueError("Type " + value + " is not supported cast to numpy type in MindRecord.") | ||||
| def _cast_name(key): | def _cast_name(key): | ||||
| """ | """ | ||||
| Cast schema names which containing special characters to valid names. | Cast schema names which containing special characters to valid names. | ||||
| @@ -97,6 +63,7 @@ def _cast_name(key): | |||||
| casted_key = ''.join(new_key) | casted_key = ''.join(new_key) | ||||
| return casted_key | return casted_key | ||||
| class TFRecordToMR: | class TFRecordToMR: | ||||
| """ | """ | ||||
| A class to transform from TFRecord to MindRecord. | A class to transform from TFRecord to MindRecord. | ||||
| @@ -120,10 +87,14 @@ class TFRecordToMR: | |||||
| Exception: when tensorflow module is not found or version is not correct. | Exception: when tensorflow module is not found or version is not correct. | ||||
| """ | """ | ||||
| def __init__(self, source, destination, feature_dict, bytes_fields=None): | def __init__(self, source, destination, feature_dict, bytes_fields=None): | ||||
| if not tf: | |||||
| try: | |||||
| self.tf = import_module("tensorflow") # just used to convert tfrecord to mindrecord | |||||
| except ModuleNotFoundError: | |||||
| self.tf = None | |||||
| if not self.tf: | |||||
| raise Exception("Module tensorflow is not found, please use pip install it.") | raise Exception("Module tensorflow is not found, please use pip install it.") | ||||
| if tf.__version__ < SupportedTensorFlowVersion: | |||||
| if self.tf.__version__ < SupportedTensorFlowVersion: | |||||
| raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion)) | raise Exception("Module tensorflow version must be greater or equal {}.".format(SupportedTensorFlowVersion)) | ||||
| if not isinstance(source, str): | if not isinstance(source, str): | ||||
| @@ -141,7 +112,7 @@ class TFRecordToMR: | |||||
| raise ValueError("Parameter feature_dict is None or not dict.") | raise ValueError("Parameter feature_dict is None or not dict.") | ||||
| for key, val in feature_dict.items(): | for key, val in feature_dict.items(): | ||||
| if not isinstance(val, tf.io.FixedLenFeature): | |||||
| if not isinstance(val, self.tf.io.FixedLenFeature): | |||||
| raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) | raise ValueError("Parameter feature_dict: {} only support FixedLenFeature.".format(feature_dict)) | ||||
| self.feature_dict = feature_dict | self.feature_dict = feature_dict | ||||
| @@ -161,7 +132,7 @@ class TFRecordToMR: | |||||
| if not isinstance(self.feature_dict[item].shape, list): | if not isinstance(self.feature_dict[item].shape, list): | ||||
| raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item)) | raise ValueError("Parameter feature_dict[{}].shape should be a list.".format(item)) | ||||
| if self.feature_dict[item].dtype != tf.string: | |||||
| if self.feature_dict[item].dtype != self.tf.string: | |||||
| raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item)) | raise ValueError("Parameter bytes_field: {} should be tf.string in feature_dict.".format(item)) | ||||
| casted_bytes_field = _cast_name(item) | casted_bytes_field = _cast_name(item) | ||||
| @@ -178,34 +149,34 @@ class TFRecordToMR: | |||||
| if _cast_name(key) in self.bytes_fields_list: | if _cast_name(key) in self.bytes_fields_list: | ||||
| mindrecord_schema[_cast_name(key)] = {"type": "bytes"} | mindrecord_schema[_cast_name(key)] = {"type": "bytes"} | ||||
| else: | else: | ||||
| mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype)} | |||||
| mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype)} | |||||
| else: | else: | ||||
| if len(val.shape) != 1: | if len(val.shape) != 1: | ||||
| raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.") | raise ValueError("Parameter len(feature_dict[{}].shape) should be 1.") | ||||
| if val.shape[0] < 1: | if val.shape[0] < 1: | ||||
| raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key)) | raise ValueError("Parameter feature_dict[{}].shape[0] should > 0".format(key)) | ||||
| if val.dtype == tf.string: | |||||
| raise ValueError("Parameter feautre_dict[{}].dtype is tf.string which shape[0] " \ | |||||
| "is not None. It is not supported.".format(key)) | |||||
| if val.dtype == self.tf.string: | |||||
| raise ValueError("Parameter feature_dict[{}].dtype is tf.string which shape[0] " | |||||
| "is not None. It is not supported.".format(key)) | |||||
| self.list_set.add(_cast_name(key)) | self.list_set.add(_cast_name(key)) | ||||
| mindrecord_schema[_cast_name(key)] = {"type": _cast_type(val.dtype), "shape": [val.shape[0]]} | |||||
| mindrecord_schema[_cast_name(key)] = {"type": self._cast_type(val.dtype), "shape": [val.shape[0]]} | |||||
| self.mindrecord_schema = mindrecord_schema | self.mindrecord_schema = mindrecord_schema | ||||
def _parse_record(self, example):
    """
    Return the features parsed from a single serialized example.

    Args:
        example: one record taken from the TFRecord dataset.

    Returns:
        The result of tf.io.parse_single_example for this record, parsed
        against self.feature_dict.
    """
    # Use the tensorflow module bound to the instance; the module-level
    # `tf` name is no longer defined once the import moves into __init__.
    return self.tf.io.parse_single_example(example, features=self.feature_dict)
def _get_data_when_scalar_field(self, ms_dict, cast_key, key, val):
    """
    Put one scalar field value into ms_dict under cast_key.

    Args:
        ms_dict (dict): output record being filled; modified in place.
        cast_key (str): key to write in ms_dict.
        key (str): key used to look up the field in self.feature_dict.
        val: parsed value exposing a numpy() method.

    Raises:
        ValueError: if val.numpy() is an ndarray or a list, not a scalar.
    """
    scalar = val.numpy()
    if isinstance(scalar, (np.ndarray, list)):
        raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val))

    field_dtype = self.feature_dict[key].dtype
    if field_dtype == self.tf.string:
        if cast_key in self.bytes_fields_list:
            # Declared as a bytes field: keep the raw value untouched.
            ms_dict[cast_key] = scalar
        else:
            ms_dict[cast_key] = str(scalar, encoding="utf-8")
    elif self._cast_type(field_dtype).startswith("int"):
        ms_dict[cast_key] = int(scalar)
    else:
        # Every remaining (non-string, non-int) type is stored as a float.
        ms_dict[cast_key] = float(scalar)
| @@ -218,7 +189,7 @@ class TFRecordToMR: | |||||
| if isinstance(val, (bytes, str)): | if isinstance(val, (bytes, str)): | ||||
| if isinstance(val, (np.ndarray, list)): | if isinstance(val, (np.ndarray, list)): | ||||
| raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) | raise ValueError("The response key: {}, value: {} from TFRecord should be a scalar.".format(key, val)) | ||||
| if self.feature_dict[key].dtype == tf.string: | |||||
| if self.feature_dict[key].dtype == self.tf.string: | |||||
| if cast_key in self.bytes_fields_list: | if cast_key in self.bytes_fields_list: | ||||
| ms_dict[cast_key] = val | ms_dict[cast_key] = val | ||||
| else: | else: | ||||
| @@ -226,7 +197,7 @@ class TFRecordToMR: | |||||
| else: | else: | ||||
| ms_dict[cast_key] = val | ms_dict[cast_key] = val | ||||
| else: | else: | ||||
| if _cast_type(self.feature_dict[key].dtype).startswith("int"): | |||||
| if self._cast_type(self.feature_dict[key].dtype).startswith("int"): | |||||
| ms_dict[cast_key] = int(val) | ms_dict[cast_key] = int(val) | ||||
| else: | else: | ||||
| ms_dict[cast_key] = float(val) | ms_dict[cast_key] = float(val) | ||||
| @@ -236,10 +207,10 @@ class TFRecordToMR: | |||||
| Yield a dict with key to be fields in schema, and value to be data. | Yield a dict with key to be fields in schema, and value to be data. | ||||
| This function is for old version tensorflow whose version number < 2.1.0 | This function is for old version tensorflow whose version number < 2.1.0 | ||||
| """ | """ | ||||
| dataset = tf.data.TFRecordDataset(self.source) | |||||
| dataset = self.tf.data.TFRecordDataset(self.source) | |||||
| dataset = dataset.map(self._parse_record) | dataset = dataset.map(self._parse_record) | ||||
| iterator = dataset.make_one_shot_iterator() | iterator = dataset.make_one_shot_iterator() | ||||
| with tf.Session() as sess: | |||||
| with self.tf.Session() as sess: | |||||
| while True: | while True: | ||||
| try: | try: | ||||
| ms_dict = {} | ms_dict = {} | ||||
| @@ -258,14 +229,14 @@ class TFRecordToMR: | |||||
| ms_dict[cast_key] = \ | ms_dict[cast_key] = \ | ||||
| np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) | np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) | ||||
| yield ms_dict | yield ms_dict | ||||
| except tf.errors.OutOfRangeError: | |||||
| except self.tf.errors.OutOfRangeError: | |||||
| break | break | ||||
| except tf.errors.InvalidArgumentError: | |||||
| except self.tf.errors.InvalidArgumentError: | |||||
| raise ValueError("TFRecord feature_dict parameter error.") | raise ValueError("TFRecord feature_dict parameter error.") | ||||
| def tfrecord_iterator(self): | def tfrecord_iterator(self): | ||||
| """Yield a dictionary whose keys are fields in schema.""" | """Yield a dictionary whose keys are fields in schema.""" | ||||
| dataset = tf.data.TFRecordDataset(self.source) | |||||
| dataset = self.tf.data.TFRecordDataset(self.source) | |||||
| dataset = dataset.map(self._parse_record) | dataset = dataset.map(self._parse_record) | ||||
| iterator = dataset.__iter__() | iterator = dataset.__iter__() | ||||
| while True: | while True: | ||||
| @@ -278,15 +249,15 @@ class TFRecordToMR: | |||||
| self._get_data_when_scalar_field(ms_dict, cast_key, key, val) | self._get_data_when_scalar_field(ms_dict, cast_key, key, val) | ||||
| else: | else: | ||||
| if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list): | if not isinstance(val.numpy(), np.ndarray) and not isinstance(val.numpy(), list): | ||||
| raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " \ | |||||
| raise ValueError("The response key: {}, value: {} from TFRecord should be a ndarray or " | |||||
| "list.".format(key, val)) | "list.".format(key, val)) | ||||
| # list set | # list set | ||||
| ms_dict[cast_key] = \ | ms_dict[cast_key] = \ | ||||
| np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) | np.asarray(val, _cast_string_type_to_np_type(self.mindrecord_schema[cast_key]["type"])) | ||||
| yield ms_dict | yield ms_dict | ||||
| except tf.errors.OutOfRangeError: | |||||
| except self.tf.errors.OutOfRangeError: | |||||
| break | break | ||||
| except tf.errors.InvalidArgumentError: | |||||
| except self.tf.errors.InvalidArgumentError: | |||||
| raise ValueError("TFRecord feature_dict parameter error.") | raise ValueError("TFRecord feature_dict parameter error.") | ||||
| def run(self): | def run(self): | ||||
| @@ -301,7 +272,7 @@ class TFRecordToMR: | |||||
| .format(self.mindrecord_schema, self.feature_dict)) | .format(self.mindrecord_schema, self.feature_dict)) | ||||
| writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") | writer.add_schema(self.mindrecord_schema, "TFRecord to MindRecord") | ||||
| if tf.__version__ < '2.0.0': | |||||
| if self.tf.__version__ < '2.0.0': | |||||
| tf_iter = self.tfrecord_iterator_oldversion() | tf_iter = self.tfrecord_iterator_oldversion() | ||||
| else: | else: | ||||
| tf_iter = self.tfrecord_iterator() | tf_iter = self.tfrecord_iterator() | ||||
| @@ -331,3 +302,35 @@ class TFRecordToMR: | |||||
| if t.exitcode != 0: | if t.exitcode != 0: | ||||
| raise t.exception | raise t.exception | ||||
| return t.res | return t.res | ||||
def _cast_type(self, value):
    """
    Cast a TFRecord data type to the basic field type MindRecord recognizes.

    The lookup uses the tensorflow module stored on the instance (self.tf).
    Narrow integer types and bool are widened to "int32", uint32/uint64 are
    stored as "int64", and float16 is stored as "float32". Any type outside
    the table (for example the complex types) is rejected.

    Args:
        value: the TFRecord data type.

    Returns:
        str, which is MindRecord field type.

    Raises:
        ValueError: if the type has no MindRecord equivalent.
    """
    tf_type_to_mr_type = {self.tf.string: "string",
                          self.tf.int8: "int32",
                          self.tf.int16: "int32",
                          self.tf.int32: "int32",
                          self.tf.int64: "int64",
                          self.tf.uint8: "int32",
                          self.tf.uint16: "int32",
                          self.tf.uint32: "int64",
                          self.tf.uint64: "int64",
                          self.tf.float16: "float32",
                          self.tf.float32: "float32",
                          self.tf.float64: "float64",
                          self.tf.double: "float64",
                          self.tf.bool: "int32"}

    if value in tf_type_to_mr_type:
        return tf_type_to_mr_type[value]

    # Format the value instead of concatenating it: callers pass a dtype
    # object, and "str + object" would raise TypeError and mask this error.
    raise ValueError("Type {} is not supported in MindRecord.".format(value))