You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

shardreader.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module is to read data from mindrecord.
  17. """
  18. import mindspore._c_mindrecord as ms
  19. from mindspore import log as logger
  20. from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError
  21. __all__ = ['ShardReader']
  22. class ShardReader:
  23. """
  24. Wrapper class which is represent ShardReader class in c++ module.
  25. The class would read a batch of data from MindRecord File series.
  26. """
  27. def __init__(self):
  28. self._reader = ms.ShardReader()
  29. def open(self, file_name, num_consumer=4, columns=None, operator=None):
  30. """
  31. Open file and prepare to read MindRecord File.
  32. Args:
  33. file_name (str, list[str]): File names of MindRecord File.
  34. num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
  35. columns (list[str]): List of fields which correspond data would be read.
  36. operator(int): Reserved parameter for operators. Default: None.
  37. Returns:
  38. MSRStatus, SUCCESS or FAILED.
  39. Raises:
  40. MRMOpenError: If failed to open MindRecord File.
  41. """
  42. columns = columns if columns else []
  43. operator = operator if operator else []
  44. if isinstance(file_name, list):
  45. load_dataset = False
  46. else:
  47. load_dataset = True
  48. file_name = [file_name]
  49. ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
  50. if ret != ms.MSRStatus.SUCCESS:
  51. logger.error("Failed to open {}.".format(file_name))
  52. raise MRMOpenError
  53. return ret
  54. def launch(self):
  55. """
  56. Launch the worker threads to load data.
  57. Returns:
  58. MSRStatus, SUCCESS or FAILED.
  59. Raises:
  60. MRMLaunchError: If failed to launch worker threads.
  61. """
  62. ret = self._reader.launch(False)
  63. if ret != ms.MSRStatus.SUCCESS:
  64. logger.error("Failed to launch worker threads.")
  65. raise MRMLaunchError
  66. return ret
  67. def get_next(self):
  68. """
  69. Return a batch of data including blob data and raw data.
  70. Returns:
  71. list of dict.
  72. """
  73. return self._reader.get_next()
  74. def get_blob_fields(self):
  75. """
  76. Return blob fields of MindRecord.
  77. Returns:
  78. list of str.
  79. """
  80. return self._reader.get_blob_fields()
  81. def get_header(self):
  82. """
  83. Return header of MindRecord.
  84. Returns:
  85. pointer object refer to header.
  86. """
  87. return self._reader.get_header()
  88. def finish(self):
  89. """
  90. stop the worker threads.
  91. Returns:
  92. MSRStatus, SUCCESS or FAILED.
  93. Raises:
  94. MRMFinishError: If failed to finish worker threads.
  95. """
  96. ret = self._reader.finish()
  97. if ret != ms.MSRStatus.SUCCESS:
  98. logger.error("Failed to finish worker threads.")
  99. raise MRMFinishError
  100. return ret
  101. def close(self):
  102. """close MindRecord File."""
  103. self._reader.close()