| @@ -26,8 +26,7 @@ from .shardheader import ShardHeader | |||
| from .shardindexgenerator import ShardIndexGenerator | |||
| from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \ | |||
| check_filename, VALUE_TYPE_MAP | |||
| from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError, \ | |||
| MRMValidateDataError | |||
| from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError | |||
| __all__ = ['FileWriter'] | |||
| @@ -201,52 +200,13 @@ class FileWriter: | |||
| raw_data.pop(i) | |||
| logger.warning(v) | |||
| def _verify_based_on_blob_fields(self, raw_data): | |||
| def write_raw_data(self, raw_data): | |||
| """ | |||
| Verify data according to blob fields which is sub set of schema's fields. | |||
| Raise exception if validation failed. | |||
| 1) allowed data type contains: "int32", "int64", "float32", "float64", "string", "bytes". | |||
| Args: | |||
| raw_data (list[dict]): List of raw data. | |||
| Raises: | |||
| MRMValidateDataError: If data does not match blob fields. | |||
| """ | |||
| schema_content = self._header.schema | |||
| for field in schema_content: | |||
| for i, v in enumerate(raw_data): | |||
| if field not in v: | |||
| raise MRMValidateDataError("for schema, {} th data is wrong: "\ | |||
| "there is not '{}' object in the raw data.".format(i, field)) | |||
| if field in self._header.blob_fields: | |||
| field_type = type(v[field]).__name__ | |||
| if field_type not in VALUE_TYPE_MAP: | |||
| raise MRMValidateDataError("for schema, {} th data is wrong: "\ | |||
| "data type for '{}' is not matched.".format(i, field)) | |||
| if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]: | |||
| raise MRMValidateDataError("for schema, {} th data is wrong: "\ | |||
| "data type for '{}' is not matched.".format(i, field)) | |||
| if field_type == 'ndarray': | |||
| if 'shape' not in schema_content[field]: | |||
| raise MRMValidateDataError("for schema, {} th data is wrong: " \ | |||
| "data type for '{}' is not matched.".format(i, field)) | |||
| try: | |||
| # tuple or list | |||
| np.reshape(v[field], schema_content[field]['shape']) | |||
| except ValueError: | |||
| raise MRMValidateDataError("for schema, {} th data is wrong: " \ | |||
| "data type for '{}' is not matched.".format(i, field)) | |||
| def write_raw_data(self, raw_data, validate=True): | |||
| """ | |||
| Write raw data and generate sequential pair of MindRecord File. | |||
| Write raw data and generate sequential pair of MindRecord File and \ | |||
| validate data based on predefined schema by default. | |||
| Args: | |||
| raw_data (list[dict]): List of raw data. | |||
| validate (bool, optional): Validate data according schema if it equals to True, | |||
| or validate data according to blob fields (default=True). | |||
| Raises: | |||
| ParamTypeError: If index field is invalid. | |||
| @@ -264,11 +224,8 @@ class FileWriter: | |||
| for each_raw in raw_data: | |||
| if not isinstance(each_raw, dict): | |||
| raise ParamTypeError('raw_data item', 'dict') | |||
| if validate is True: | |||
| self._verify_based_on_schema(raw_data) | |||
| elif validate is False: | |||
| self._verify_based_on_blob_fields(raw_data) | |||
| return self._writer.write_raw_data(raw_data, validate) | |||
| self._verify_based_on_schema(raw_data) | |||
| return self._writer.write_raw_data(raw_data, True) | |||
| def set_header_size(self, header_size): | |||
| """ | |||