|
|
|
@@ -26,8 +26,7 @@ from .shardheader import ShardHeader |
|
|
|
from .shardindexgenerator import ShardIndexGenerator |
|
|
|
from .shardutils import MIN_SHARD_COUNT, MAX_SHARD_COUNT, VALID_ATTRIBUTES, VALID_ARRAY_ATTRIBUTES, \ |
|
|
|
check_filename, VALUE_TYPE_MAP |
|
|
|
from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError, \ |
|
|
|
MRMValidateDataError |
|
|
|
from .common.exceptions import ParamValueError, ParamTypeError, MRMInvalidSchemaError, MRMDefineIndexError |
|
|
|
|
|
|
|
__all__ = ['FileWriter'] |
|
|
|
|
|
|
|
@@ -201,52 +200,13 @@ class FileWriter: |
|
|
|
raw_data.pop(i) |
|
|
|
logger.warning(v) |
|
|
|
|
|
|
|
def _verify_based_on_blob_fields(self, raw_data): |
|
|
|
def write_raw_data(self, raw_data): |
|
|
|
""" |
|
|
|
Verify data according to blob fields which is sub set of schema's fields. |
|
|
|
|
|
|
|
Raise exception if validation failed. |
|
|
|
1) allowed data type contains: "int32", "int64", "float32", "float64", "string", "bytes". |
|
|
|
|
|
|
|
Args: |
|
|
|
raw_data (list[dict]): List of raw data. |
|
|
|
|
|
|
|
Raises: |
|
|
|
MRMValidateDataError: If data does not match blob fields. |
|
|
|
""" |
|
|
|
schema_content = self._header.schema |
|
|
|
for field in schema_content: |
|
|
|
for i, v in enumerate(raw_data): |
|
|
|
if field not in v: |
|
|
|
raise MRMValidateDataError("for schema, {} th data is wrong: "\ |
|
|
|
"there is not '{}' object in the raw data.".format(i, field)) |
|
|
|
if field in self._header.blob_fields: |
|
|
|
field_type = type(v[field]).__name__ |
|
|
|
if field_type not in VALUE_TYPE_MAP: |
|
|
|
raise MRMValidateDataError("for schema, {} th data is wrong: "\ |
|
|
|
"data type for '{}' is not matched.".format(i, field)) |
|
|
|
if schema_content[field]["type"] not in VALUE_TYPE_MAP[field_type]: |
|
|
|
raise MRMValidateDataError("for schema, {} th data is wrong: "\ |
|
|
|
"data type for '{}' is not matched.".format(i, field)) |
|
|
|
if field_type == 'ndarray': |
|
|
|
if 'shape' not in schema_content[field]: |
|
|
|
raise MRMValidateDataError("for schema, {} th data is wrong: " \ |
|
|
|
"data type for '{}' is not matched.".format(i, field)) |
|
|
|
try: |
|
|
|
# tuple or list |
|
|
|
np.reshape(v[field], schema_content[field]['shape']) |
|
|
|
except ValueError: |
|
|
|
raise MRMValidateDataError("for schema, {} th data is wrong: " \ |
|
|
|
"data type for '{}' is not matched.".format(i, field)) |
|
|
|
|
|
|
|
def write_raw_data(self, raw_data, validate=True): |
|
|
|
""" |
|
|
|
Write raw data and generate sequential pair of MindRecord File. |
|
|
|
Write raw data and generate sequential pair of MindRecord File and \ |
|
|
|
validate data based on predefined schema by default. |
|
|
|
|
|
|
|
Args: |
|
|
|
raw_data (list[dict]): List of raw data. |
|
|
|
validate (bool, optional): Validate data according schema if it equals to True, |
|
|
|
or validate data according to blob fields (default=True). |
|
|
|
|
|
|
|
Raises: |
|
|
|
ParamTypeError: If index field is invalid. |
|
|
|
@@ -264,11 +224,8 @@ class FileWriter: |
|
|
|
for each_raw in raw_data: |
|
|
|
if not isinstance(each_raw, dict): |
|
|
|
raise ParamTypeError('raw_data item', 'dict') |
|
|
|
if validate is True: |
|
|
|
self._verify_based_on_schema(raw_data) |
|
|
|
elif validate is False: |
|
|
|
self._verify_based_on_blob_fields(raw_data) |
|
|
|
return self._writer.write_raw_data(raw_data, validate) |
|
|
|
self._verify_based_on_schema(raw_data) |
|
|
|
return self._writer.write_raw_data(raw_data, True) |
|
|
|
|
|
|
|
def set_header_size(self, header_size): |
|
|
|
""" |
|
|
|
|