You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

shardsegment.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module reads pages of data from MindRecord files.
  17. """
  18. import mindspore._c_mindrecord as ms
  19. from mindspore import log as logger
  20. from .shardutils import populate_data, SUCCESS
  21. from .shardheader import ShardHeader
  22. from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError
  23. __all__ = ['ShardSegment']
  class ShardSegment:
      """
      Python wrapper around the ShardSegment class of the C++ module.

      The class queries data from MindRecord File page by page, where the
      rows are grouped by a selectable category field.
      """

      def __init__(self):
          self._segment = ms.ShardSegment()
          # Both attributes are filled in by open().
          self._header = None
          self._columns = None

      def open(self, file_name, num_consumer=4, columns=None, operator=None):
          """
          Initialize the ShardSegment.

          Args:
              file_name (str, list[str]): File names of MindRecord File.
              num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
              columns (list[str]): List of fields whose corresponding data would be read.
              operator (int): Reserved parameter for operators. Default: None.

          Returns:
              MSRStatus, SUCCESS or FAILED.

          Raises:
              MRMOpenError: If failed to open MindRecord File.
          """
          self._columns = columns or []
          operator = operator or []
          # A list is handed over as-is with load_dataset=False; anything else
          # is wrapped into a one-element list and opened with load_dataset=True.
          load_dataset = not isinstance(file_name, list)
          if load_dataset:
              file_name = [file_name]
          ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator)
          if ret != SUCCESS:
              logger.error("Failed to open {}.".format(file_name))
              raise MRMOpenError
          self._header = ShardHeader(self._segment.get_header())
          return ret

      def get_category_fields(self):
          """
          Get candidate category fields.

          Returns:
              list[str], by which data could be grouped.

          Raises:
              MRMFetchCandidateFieldsError: If failed to get candidate category fields.
          """
          ret, fields = self._segment.get_category_fields()
          if ret != SUCCESS:
              logger.error("Failed to get candidate category fields.")
              raise MRMFetchCandidateFieldsError
          return fields

      def set_category_field(self, category_field):
          """
          Select one category field to use.

          Args:
              category_field: The category field to select; it is forwarded
                  unchanged to the underlying segment object.

          Returns:
              Whatever the underlying segment object returns for the call.
          """
          return self._segment.set_category_field(category_field)

      def read_category_info(self):
          """
          Get the group info by the current category field.

          Returns:
              str, description of group information.

          Raises:
              MRMReadCategoryInfoError: If failed to read category information.
          """
          ret, category_info = self._segment.read_category_info()
          if ret != SUCCESS:
              logger.error("Failed to read category information.")
              raise MRMReadCategoryInfoError
          return category_info

      def _populate_rows(self, data):
          """Turn the (blob, raw) pairs of one page into a list of populated rows."""
          return [populate_data(raw, blob, self._columns, self._header.blob_fields,
                                self._header.schema) for blob, raw in data]

      def read_at_page_by_id(self, category_id, page, num_row):
          """
          Get the data of some page by category id.

          Args:
              category_id (int): Category id, referred to the return of read_category_info.
              page (int): Index of page.
              num_row (int): Number of rows in a page.

          Returns:
              list[dict]

          Raises:
              MRMFetchDataError: If failed to read by category id.
              MRMUnsupportedSchemaError: If schema is invalid.
          """
          ret, data = self._segment.read_at_page_by_id(category_id, page, num_row)
          if ret != SUCCESS:
              logger.error("Failed to read by category id.")
              raise MRMFetchDataError
          return self._populate_rows(data)

      def read_at_page_by_name(self, category_name, page, num_row):
          """
          Get the data of some page by category name.

          Args:
              category_name (str): Category name, referred to the return of read_category_info.
              page (int): Index of page.
              num_row (int): Number of rows in a page.

          Returns:
              list[dict]

          Raises:
              MRMFetchDataError: If failed to read by category name.
              MRMUnsupportedSchemaError: If schema is invalid.
          """
          ret, data = self._segment.read_at_page_by_name(category_name, page, num_row)
          if ret != SUCCESS:
              logger.error("Failed to read by category name.")
              raise MRMFetchDataError
          return self._populate_rows(data)