You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mindpage.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module is to support reading page from mindrecord.
  17. """
  18. from mindspore import log as logger
  19. from .shardsegment import ShardSegment
  20. from .shardutils import MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT, check_filename
  21. from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategoryError
  22. __all__ = ['MindPage']
  23. class MindPage:
  24. """
  25. Class to read MindRecord File series in pagination.
  26. Args:
  27. file_name (str): File name of MindRecord File.
  28. num_consumer(int, optional): Number of consumer threads which load data to memory (default=4).
  29. It should not be smaller than 1 or larger than the number of CPU.
  30. Raises:
  31. ParamValueError: If file_name, num_consumer or columns is invalid.
  32. MRMInitSegmentError: If failed to initialize ShardSegment.
  33. """
  34. def __init__(self, file_name, num_consumer=4):
  35. check_filename(file_name)
  36. self._file_name = file_name
  37. if num_consumer is not None:
  38. if isinstance(num_consumer, int):
  39. if num_consumer < MIN_CONSUMER_COUNT or num_consumer > MAX_CONSUMER_COUNT():
  40. raise ParamValueError("Consumer number should between {} and {}."
  41. .format(MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT()))
  42. else:
  43. raise ParamValueError("Consumer number is illegal.")
  44. else:
  45. raise ParamValueError("Consumer number is illegal.")
  46. self._segment = ShardSegment()
  47. self._segment.open(file_name, num_consumer)
  48. self._category_field = None
  49. self._candidate_fields = [field[:field.rfind('_')] for field in self._segment.get_category_fields()]
  50. @property
  51. def candidate_fields(self):
  52. """
  53. Return candidate category fields.
  54. Returns:
  55. list[str], by which data could be grouped.
  56. """
  57. return self._candidate_fields
  58. def get_category_fields(self):
  59. """Return candidate category fields."""
  60. logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated."
  61. " Please use candidate_fields")
  62. return self.candidate_fields
  63. def set_category_field(self, category_field):
  64. """
  65. Set category field for reading.
  66. Note:
  67. Should be a candidate category field.
  68. Args:
  69. category_field (str): String of category field name.
  70. Returns:
  71. MSRStatus, SUCCESS or FAILED.
  72. """
  73. logger.warning("WARN_DEPRECATED: The usage of set_category_field is deprecated."
  74. " Please use category_field")
  75. if not category_field or not isinstance(category_field, str):
  76. raise ParamTypeError('category_fields', 'str')
  77. if category_field not in self._candidate_fields:
  78. raise MRMDefineCategoryError("Field '{}' is not a candidate category field.".format(category_field))
  79. return self._segment.set_category_field(category_field)
  80. @property
  81. def category_field(self):
  82. """Getter function for category field"""
  83. return self._category_field
  84. @category_field.setter
  85. def category_field(self, category_field):
  86. """Setter function for category field"""
  87. if not category_field or not isinstance(category_field, str):
  88. raise ParamTypeError('category_fields', 'str')
  89. if category_field not in self._candidate_fields:
  90. raise MRMDefineCategoryError("Field '{}' is not a candidate category field.".format(category_field))
  91. self._category_field = category_field
  92. return self._segment.set_category_field(self._category_field)
  93. def read_category_info(self):
  94. """
  95. Return category information when data is grouped by indicated category field.
  96. Returns:
  97. str, description of group information.
  98. Raises:
  99. MRMReadCategoryInfoError: If failed to read category information.
  100. """
  101. return self._segment.read_category_info()
  102. def read_at_page_by_id(self, category_id, page, num_row):
  103. """
  104. Query by category id in pagination.
  105. Args:
  106. category_id (int): Category id, referred to the return of read_category_info.
  107. page (int): Index of page.
  108. num_row (int): Number of rows in a page.
  109. Returns:
  110. List, list[dict].
  111. Raises:
  112. ParamValueError: If any parameter is invalid.
  113. MRMFetchDataError: If failed to fetch data by category.
  114. MRMUnsupportedSchemaError: If schema is invalid.
  115. """
  116. if not isinstance(category_id, int) or category_id < 0:
  117. raise ParamValueError("Category id should be int and greater than or equal to 0.")
  118. if not isinstance(page, int) or page < 0:
  119. raise ParamValueError("Page should be int and greater than or equal to 0.")
  120. if not isinstance(num_row, int) or num_row <= 0:
  121. raise ParamValueError("num_row should be int and greater than 0.")
  122. return self._segment.read_at_page_by_id(category_id, page, num_row)
  123. def read_at_page_by_name(self, category_name, page, num_row):
  124. """
  125. Query by category name in pagination.
  126. Args:
  127. category_name (str): String of category field's value,
  128. referred to the return of read_category_info.
  129. page (int): Index of page.
  130. num_row (int): Number of row in a page.
  131. Returns:
  132. str, read at page.
  133. """
  134. if not isinstance(category_name, str):
  135. raise ParamValueError("Category name should be str.")
  136. if not isinstance(page, int) or page < 0:
  137. raise ParamValueError("Page should be int and greater than or equal to 0.")
  138. if not isinstance(num_row, int) or num_row <= 0:
  139. raise ParamValueError("num_row should be int and greater than 0.")
  140. return self._segment.read_at_page_by_name(category_name, page, num_row)