You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

shardreader.py 3.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module is to read data from mindrecord.
  17. """
  18. import mindspore._c_mindrecord as ms
  19. from mindspore import log as logger
  20. from .common.exceptions import MRMOpenError, MRMLaunchError
  21. __all__ = ['ShardReader']
  22. class ShardReader:
  23. """
  24. Wrapper class which is represent ShardReader class in c++ module.
  25. The class would read a batch of data from MindRecord File series.
  26. """
  27. def __init__(self):
  28. self._reader = ms.ShardReader()
  29. def open(self, file_name, num_consumer=4, columns=None, operator=None):
  30. """
  31. Open file and prepare to read MindRecord File.
  32. Args:
  33. file_name (str, list[str]): File names of MindRecord File.
  34. num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
  35. columns (list[str]): List of fields which correspond data would be read.
  36. operator(int): Reserved parameter for operators. Default: None.
  37. Returns:
  38. MSRStatus, SUCCESS or FAILED.
  39. Raises:
  40. MRMOpenError: If failed to open MindRecord File.
  41. """
  42. columns = columns if columns else []
  43. operator = operator if operator else []
  44. if isinstance(file_name, list):
  45. load_dataset = False
  46. else:
  47. load_dataset = True
  48. file_name = [file_name]
  49. ret = self._reader.open(file_name, load_dataset, num_consumer, columns, operator)
  50. if ret != ms.MSRStatus.SUCCESS:
  51. logger.critical("Failed to open {}.".format(file_name))
  52. self.close()
  53. raise MRMOpenError
  54. return ret
  55. def launch(self):
  56. """
  57. Launch the worker threads to load data.
  58. Returns:
  59. MSRStatus, SUCCESS or FAILED.
  60. Raises:
  61. MRMLaunchError: If failed to launch worker threads.
  62. """
  63. ret = self._reader.launch()
  64. if ret != ms.MSRStatus.SUCCESS:
  65. logger.critical("Failed to launch worker threads.")
  66. raise MRMLaunchError
  67. return ret
  68. def get_next(self):
  69. """
  70. Return a batch of data including blob data and raw data.
  71. Returns:
  72. list of dict.
  73. """
  74. return self._reader.get_next()
  75. def get_blob_fields(self):
  76. """
  77. Return blob fields of MindRecord.
  78. Returns:
  79. list of str.
  80. """
  81. return self._reader.get_blob_fields()
  82. def get_header(self):
  83. """
  84. Return header of MindRecord.
  85. Returns:
  86. pointer object refer to header.
  87. """
  88. return self._reader.get_header()
  89. def close(self):
  90. """close MindRecord File."""
  91. self._reader.close()