# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module is to write data into mindrecord.
"""
import os
import sys
import threading
import traceback

import numpy as np
import mindspore._c_mindrecord as ms
from .common.exceptions import ParamValueError, MRMUnsupportedSchemaError

SUCCESS = ms.MSRStatus.SUCCESS
FAILED = ms.MSRStatus.FAILED
DATASET_NLP = ms.ShardType.NLP
DATASET_CV = ms.ShardType.CV

MIN_HEADER_SIZE = ms.MIN_HEADER_SIZE
MAX_HEADER_SIZE = ms.MAX_HEADER_SIZE
MIN_PAGE_SIZE = ms.MIN_PAGE_SIZE
MAX_PAGE_SIZE = ms.MAX_PAGE_SIZE
MIN_SHARD_COUNT = ms.MIN_SHARD_COUNT
MAX_SHARD_COUNT = ms.MAX_SHARD_COUNT
MIN_CONSUMER_COUNT = ms.MIN_CONSUMER_COUNT
# NOTE(review): this is bound to the function object itself (no call here), so
# users of MAX_CONSUMER_COUNT presumably have to call it -- verify at call sites.
MAX_CONSUMER_COUNT = ms.get_max_thread_num

VALUE_TYPE_MAP = {"int": ["int32", "int64"], "float": ["float32", "float64"], "str": "string", "bytes": "bytes",
                  "int32": "int32", "int64": "int64", "float32": "float32", "float64": "float64",
                  "ndarray": ["int32", "int64", "float32", "float64"]}

VALID_ATTRIBUTES = ["int32", "int64", "float32", "float64", "string", "bytes"]
VALID_ARRAY_ATTRIBUTES = ["int32", "int64", "float32", "float64"]


class ExceptionThread(threading.Thread):
    """
    Thread that keeps the target's result or exception for later inspection.

    After the thread has run:
        - res holds the target's return value (SUCCESS until then).
        - exitcode is 0 on normal completion and 1 if the target raised.
        - exception holds the raised exception instance, or None.
        - exc_traceback holds the formatted traceback text, or ''.
    """

    def __init__(self, *args, **kwargs):
        threading.Thread.__init__(self, *args, **kwargs)
        self.res = SUCCESS
        self.exitcode = 0
        self.exception = None
        self.exc_traceback = ''

    def run(self):
        """Run the target, recording its return value or the exception it raised."""
        try:
            if self._target:
                self.res = self._target(*self._args, **self._kwargs)
        except Exception as e:  # pylint: disable=W0703
            # The exception is not re-raised here; it is stored on the thread
            # object so the code that started the thread can examine it.
            self.exitcode = 1
            self.exception = e
            self.exc_traceback = ''.join(traceback.format_exception(*sys.exc_info()))


def check_filename(path):
    """
    Check the file name (last component) of the path.

    Args:
        path (str): the path.

    Raises:
        ParamValueError: If path is None or empty, is not a string, if the
            file name contains a forbidden character, or if the file name
            starts or ends with a space.

    Returns:
        Bool, True when the file name is valid (invalid names raise).
    """
    if not path:
        raise ParamValueError('File path is not allowed None or empty!')
    if not isinstance(path, str):
        raise ParamValueError("File path: {} is not string.".format(path))
    file_name = os.path.basename(path)

    # NOTE(review): the list below looks like a wider set of candidate symbols;
    # only the subset placed in forbidden_symbols is actually rejected.
    # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
    # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
    # '*', '(', '%', ')', '-', '=', '{', '?', '$'
    forbidden_symbols = set(r'\/:*?"<>|`&\';')

    if set(file_name) & forbidden_symbols:
        raise ParamValueError(r"File name should not contains \/:*?\"<>|`&;\'")

    if file_name.startswith(' ') or file_name.endswith(' '):
        raise ParamValueError("File name should not start/end with space.")

    return True


def populate_data(raw, blob, columns, blob_fields, schema):
    """
    Reconstruct data from raw and blob data.

    Args:
        raw (Dict): Data contain primitive data like "int32", "int64", "float32", "float64", "string", "bytes".
        blob (Bytes): Data contain bytes and ndarray data.
        columns(List): List of column name which will be populated.
        blob_fields (List): Refer to the field which data stored in blob.
        schema(Dict): Dict of Schema

    Raises:
        MRMUnsupportedSchemaError: If schema is invalid.

    Returns:
        Dict, the raw fields that are present in schema, plus one entry per
        loaded blob field (an ndarray when the schema gives the field a
        non-empty shape, otherwise the field's bytes).
    """
    if raw:
        # remove dummy fields
        raw = {k: v for k, v in raw.items() if k in schema}
    else:
        raw = {}
    if not blob_fields:
        return raw

    # Without an explicit column selection every blob field is loaded.
    if columns:
        loaded_columns = [column for column in columns if column in blob_fields]
    else:
        loaded_columns = blob_fields

    def _render_raw(field, blob_data):
        """Store blob_data under field, reshaped to an ndarray if the schema has a shape."""
        field_schema = schema[field]
        data_type = field_schema['type']
        data_shape = field_schema.get('shape', [])
        if data_shape:
            try:
                raw[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape)
            except ValueError as err:
                # Chain the numpy error so the underlying size mismatch stays visible.
                raise MRMUnsupportedSchemaError('Shape in schema is illegal.') from err
        else:
            raw[field] = blob_data

    # blob[i] is paired with the i-th loaded column.
    # NOTE(review): assumes blob is ordered to match loaded_columns -- confirm with the caller.
    for i, blob_field in enumerate(loaded_columns):
        _render_raw(blob_field, bytes(blob[i]))
    return raw