You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

shardutils.py 5.0 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module is to write data into mindrecord.
  17. """
  18. import os
  19. import sys
  20. import threading
  21. import traceback
  22. import numpy as np
  23. import mindspore._c_mindrecord as ms
  24. from .common.exceptions import ParamValueError, MRMUnsupportedSchemaError
  # Status values re-exported from the `ms` binding module.
  SUCCESS = ms.MSRStatus.SUCCESS
  FAILED = ms.MSRStatus.FAILED

  # Shard type tags re-exported from the `ms` binding module.
  DATASET_NLP = ms.ShardType.NLP
  DATASET_CV = ms.ShardType.CV

  # Lower/upper bounds, taken unchanged from the `ms` binding module.
  MIN_HEADER_SIZE = ms.MIN_HEADER_SIZE
  MAX_HEADER_SIZE = ms.MAX_HEADER_SIZE
  MIN_PAGE_SIZE = ms.MIN_PAGE_SIZE
  MAX_PAGE_SIZE = ms.MAX_PAGE_SIZE
  MIN_SHARD_COUNT = ms.MIN_SHARD_COUNT
  MAX_SHARD_COUNT = ms.MAX_SHARD_COUNT
  MIN_CONSUMER_COUNT = ms.MIN_CONSUMER_COUNT
  # NOTE(review): this binds the attribute itself and does not call it; if
  # `get_max_thread_num` is a function, users have to invoke
  # MAX_CONSUMER_COUNT() to obtain the number - confirm against callers.
  MAX_CONSUMER_COUNT = ms.get_max_thread_num

  # Type-name lookup table. Presumably maps the name of a Python/NumPy value
  # type to the schema type name(s) accepted for it - confirm against callers.
  VALUE_TYPE_MAP = {
      "int": ["int32", "int64"],
      "float": ["float32", "float64"],
      "str": "string",
      "bytes": "bytes",
      "int32": "int32",
      "int64": "int64",
      "float32": "float32",
      "float64": "float64",
      "ndarray": ["int32", "int64", "float32", "float64"],
  }

  # Type names listed as valid attributes, and the subset listed as valid for
  # array attributes.
  VALID_ATTRIBUTES = ["int32", "int64", "float32", "float64", "string", "bytes"]
  VALID_ARRAY_ATTRIBUTES = ["int32", "int64", "float32", "float64"]
  class ExceptionThread(threading.Thread):
      """
      Thread that stores the outcome of its target instead of losing it.

      After the thread has run, the following attributes can be inspected:

      - res: value returned by the target; starts out as SUCCESS.
      - exitcode: 0 initially, set to 1 when the target raised an Exception.
      - exception: the caught exception object, or None.
      - exc_traceback: formatted traceback text of the caught exception,
        or an empty string.
      """

      def __init__(self, *args, **kwargs):
          super().__init__(*args, **kwargs)
          self.res = SUCCESS
          self.exitcode = 0
          self.exception = None
          self.exc_traceback = ''

      def run(self):
          """
          Call the target and record its return value or the exception raised.

          Only subclasses of Exception are captured; they are stored on the
          instance rather than propagated out of the thread.
          """
          target = self._target
          try:
              if target:
                  self.res = target(*self._args, **self._kwargs)
          except Exception as err:  # pylint: disable=W0703
              self.exitcode = 1
              self.exception = err
              # Same text as joining format_exception(*sys.exc_info()).
              self.exc_traceback = traceback.format_exc()
  def check_filename(path):
      """
      Validate the file name component of a path.

      Args:
          path (str): The path whose last component is checked.

      Raises:
          ParamValueError: If path is None or empty, is not a string, ends
              with '/', or if its file name contains a forbidden character
              or starts/ends with a space.

      Returns:
          Bool, always True when no error is raised.
      """
      if not path:
          raise ParamValueError('File path is not allowed None or empty!')
      if not isinstance(path, str):
          raise ParamValueError("File path: {} is not string.".format(path))
      # path is a non-empty str at this point, so indexing is safe.
      if path[-1] == "/":
          raise ParamValueError("File path can not end with '/'")

      name = os.path.basename(path)

      # NOTE(review): symbol list left by the original author; only the
      # subset in `banned` below is actually enforced.
      # '#', ':', '|', ' ', '}', '"', '+', '!', ']', '[', '\\', '`',
      # '&', '.', '/', '@', "'", '^', ',', '_', '<', ';', '~', '>',
      # '*', '(', '%', ')', '-', '=', '{', '?', '$'
      banned = frozenset(r'\/:*?"<>|`&\';')
      if not banned.isdisjoint(name):
          raise ParamValueError(r"File name should not contains \/:*?\"<>|`&;\'")

      # Stripping spaces changes the name only if it starts or ends with one.
      if name != name.strip(' '):
          raise ParamValueError("File name should not start/end with space.")
      return True
  def populate_data(raw, blob, columns, blob_fields, schema):
      """
      Reconstruct data from raw and blob data.

      Args:
          raw (Dict): Data containing primitive values like "int32", "int64",
              "float32", "float64", "string", "bytes". Keys that are not in
              the schema are dropped. May be None or empty.
          blob (List): Indexable collection of blob payloads. Item i is
              converted with bytes() and assigned to the i-th loaded blob
              column.
          columns (List): List of column names which will be populated. When
              empty or None, every field in blob_fields is populated.
          blob_fields (List): Names of the fields whose data is stored in blob.
          schema (Dict): Dict of schema, keyed by field name. Each entry has a
              'type' and optionally a 'shape'.

      Raises:
          MRMUnsupportedSchemaError: If a blob payload cannot be interpreted
              with the type and shape given in the schema.

      Returns:
          Dict, a new dict with the schema fields of raw plus the populated
          blob fields. The raw argument itself is not modified.
      """
      # Keep only fields declared in the schema (removes dummy fields).
      record = {key: value for key, value in raw.items() if key in schema} if raw else {}
      if not blob_fields:
          return record

      if columns:
          loaded_columns = [column for column in columns if column in blob_fields]
      else:
          loaded_columns = blob_fields

      for index, field in enumerate(loaded_columns):
          blob_data = bytes(blob[index])
          field_schema = schema[field]
          data_type = field_schema['type']
          data_shape = field_schema.get('shape', [])
          if not data_shape:
              # No shape declared: the payload is kept as plain bytes.
              record[field] = blob_data
              continue
          try:
              record[field] = np.reshape(np.frombuffer(blob_data, dtype=data_type), data_shape)
          except ValueError as err:
              # Chain explicitly so the NumPy error stays visible as the cause.
              raise MRMUnsupportedSchemaError('Shape in schema is illegal.') from err
      return record