You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

shardsegment.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""
This module reads data from a MindRecord file page by page, grouped by a category field.
"""
  18. import mindspore._c_mindrecord as ms
  19. from mindspore import log as logger
  20. from .shardutils import populate_data, SUCCESS
  21. from .shardheader import ShardHeader
  22. from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError
  23. __all__ = ['ShardSegment']
  24. class ShardSegment:
  25. """
  26. Wrapper class which is represent ShardSegment class in c++ module.
  27. The class would query data from MindRecord File in pagination.
  28. """
  29. def __init__(self):
  30. self._segment = ms.ShardSegment()
  31. self._header = None
  32. self._columns = None
  33. def open(self, file_name, num_consumer=4, columns=None, operator=None):
  34. """
  35. Initialize the ShardSegment.
  36. Args:
  37. file_name (str): File name of MindRecord File.
  38. num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
  39. columns (list[str]): List of fields which correspond data would be read.
  40. operator(int): Reserved parameter for operators. Default: None.
  41. Returns:
  42. MSRStatus, SUCCESS or FAILED.
  43. Raises:
  44. MRMOpenError: If failed to open MindRecord File.
  45. """
  46. self._columns = columns if columns else []
  47. operator = operator if operator else []
  48. ret = self._segment.open(file_name, num_consumer, self._columns, operator)
  49. if ret != SUCCESS:
  50. logger.error("Failed to open {}.".format(file_name))
  51. raise MRMOpenError
  52. self._header = ShardHeader(self._segment.get_header())
  53. return ret
  54. def get_category_fields(self):
  55. """
  56. Get candidate category fields.
  57. Returns:
  58. list[str], by which data could be grouped.
  59. Raises:
  60. MRMFetchCandidateFieldsError: If failed to get candidate category fields.
  61. """
  62. ret, fields = self._segment.get_category_fields()
  63. if ret != SUCCESS:
  64. logger.error("Failed to get candidate category fields.")
  65. raise MRMFetchCandidateFieldsError
  66. return fields
  67. def set_category_field(self, category_field):
  68. """Select one category field to use."""
  69. return self._segment.set_category_field(category_field)
  70. def read_category_info(self):
  71. """
  72. Get the group info by the current category field.
  73. Returns:
  74. str, description fo group information.
  75. Raises:
  76. MRMReadCategoryInfoError: If failed to read category information.
  77. """
  78. ret, category_info = self._segment.read_category_info()
  79. if ret != SUCCESS:
  80. logger.error("Failed to read category information.")
  81. raise MRMReadCategoryInfoError
  82. return category_info
  83. def read_at_page_by_id(self, category_id, page, num_row):
  84. """
  85. Get the data of some page by category id.
  86. Args:
  87. category_id (int): Category id, referred to the return of read_category_info.
  88. page (int): Index of page.
  89. num_row (int): Number of rows in a page.
  90. Returns:
  91. list[dict]
  92. Raises:
  93. MRMFetchDataError: If failed to read by category id.
  94. MRMUnsupportedSchemaError: If schema is invalid.
  95. """
  96. ret, data = self._segment.read_at_page_by_id(category_id, page, num_row)
  97. if ret != SUCCESS:
  98. logger.error("Failed to read by category id.")
  99. raise MRMFetchDataError
  100. return [populate_data(raw, blob, self._columns, self._header.blob_fields,
  101. self._header.schema) for blob, raw in data]
  102. def read_at_page_by_name(self, category_name, page, num_row):
  103. """
  104. Get the data of some page by category name.
  105. Args:
  106. category_name (str): Category name, referred to the return of read_category_info.
  107. page (int): Index of page.
  108. num_row (int): Number of rows in a page.
  109. Returns:
  110. list[dict]
  111. Raises:
  112. MRMFetchDataError: If failed to read by category name.
  113. MRMUnsupportedSchemaError: If schema is invalid.
  114. """
  115. ret, data = self._segment.read_at_page_by_name(category_name, page, num_row)
  116. if ret != SUCCESS:
  117. logger.error("Failed to read by category name.")
  118. raise MRMFetchDataError
  119. return [populate_data(raw, blob, self._columns, self._header.blob_fields,
  120. self._header.schema) for blob, raw in data]