You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mindpage.py 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module is to support reading page from mindrecord.
  17. """
  18. from mindspore import log as logger
  19. from .shardsegment import ShardSegment
  20. from .shardutils import MIN_CONSUMER_COUNT, MAX_CONSUMER_COUNT, check_filename
  21. from .common.exceptions import ParamValueError, ParamTypeError, MRMDefineCategoryError
  __all__ = ['MindPage']


  class MindPage:
      """
      Class to read MindRecord File series in pagination.

      Args:
          file_name (Union[str, list[str]]): One MindRecord file name or a list of file names.
          num_consumer (int, optional): The number of consumer threads which load data to memory (default=4).
              It should not be smaller than 1 or larger than the number of CPUs.

      Raises:
          ParamValueError: If `num_consumer` is not an int or is out of range.
              Invalid file names are rejected by `check_filename`.
          MRMInitSegmentError: If failed to initialize ShardSegment.
      """

      def __init__(self, file_name, num_consumer=4):
          # Validate every file name before opening anything; a single name is
          # treated as a one-element list for the check only.
          file_names = file_name if isinstance(file_name, list) else [file_name]
          for name in file_names:
              check_filename(name)

          # None and any non-int value are rejected with the same message.
          if not isinstance(num_consumer, int):
              raise ParamValueError("Consumer number is illegal.")
          max_consumer = MAX_CONSUMER_COUNT()
          if num_consumer < MIN_CONSUMER_COUNT or num_consumer > max_consumer:
              raise ParamValueError("Consumer number should between {} and {}."
                                    .format(MIN_CONSUMER_COUNT, max_consumer))

          self._segment = ShardSegment()
          self._segment.open(file_name, num_consumer)
          self._category_field = None
          self._candidate_fields = [self._strip_field_suffix(field)
                                    for field in self._segment.get_category_fields()]

      @staticmethod
      def _strip_field_suffix(field):
          """
          Drop everything from the last '_' of a raw category field name.

          Names without '_' are returned unchanged. A plain `field[:field.rfind('_')]`
          would chop their last character, because `rfind` returns -1.
          """
          # NOTE(review): the segment's raw names appear to carry a '_<suffix>'
          # (the stripping is inherited from the original code) -- confirm upstream.
          prefix, sep, _ = field.rpartition('_')
          return prefix if sep else field

      @property
      def candidate_fields(self):
          """
          Return candidate category fields.

          Returns:
              list[str], by which data could be grouped.
          """
          return self._candidate_fields

      def get_category_fields(self):
          """
          Return candidate category fields (deprecated, use `candidate_fields`).

          Returns:
              list[str], the same list as `candidate_fields`.
          """
          logger.warning("WARN_DEPRECATED: The usage of get_category_fields is deprecated."
                         " Please use candidate_fields")
          return self.candidate_fields

      def _validate_category_field(self, category_field):
          """Raise unless `category_field` is a non-empty str naming a candidate category field."""
          if not category_field or not isinstance(category_field, str):
              raise ParamTypeError('category_field', 'str')
          if category_field not in self._candidate_fields:
              raise MRMDefineCategoryError("Field '{}' is not a candidate category field.".format(category_field))

      def _apply_category_field(self, category_field):
          """Validate `category_field`, set it on the segment, cache it and return the segment's result."""
          self._validate_category_field(category_field)
          status = self._segment.set_category_field(category_field)
          # Cache only after the segment call returned, so a raising segment
          # never leaves `category_field` reporting a field it was not set to.
          self._category_field = category_field
          return status

      def set_category_field(self, category_field):
          """
          Set category field for reading (deprecated, assign `category_field` instead).

          Note:
              Should be a candidate category field.

          Args:
              category_field (str): String of category field name.

          Returns:
              MSRStatus, SUCCESS or FAILED.

          Raises:
              ParamTypeError: If `category_field` is empty or not a str.
              MRMDefineCategoryError: If `category_field` is not a candidate category field.
          """
          logger.warning("WARN_DEPRECATED: The usage of set_category_field is deprecated."
                         " Please use category_field")
          # Shares the setter's path so the `category_field` getter stays in sync.
          return self._apply_category_field(category_field)

      @property
      def category_field(self):
          """Return the category field currently used for reading, or None if none has been set."""
          return self._category_field

      @category_field.setter
      def category_field(self, category_field):
          """
          Validate `category_field` and apply it to the underlying segment.

          Raises:
              ParamTypeError: If `category_field` is empty or not a str.
              MRMDefineCategoryError: If `category_field` is not a candidate category field.
          """
          self._apply_category_field(category_field)

      def read_category_info(self):
          """
          Return category information when data is grouped by indicated category field.

          Returns:
              str, description of group information.

          Raises:
              MRMReadCategoryInfoError: If failed to read category information.
          """
          return self._segment.read_category_info()

      @staticmethod
      def _check_page_args(page, num_row):
          """Raise ParamValueError unless `page` is an int >= 0 and `num_row` is an int > 0."""
          if not isinstance(page, int) or page < 0:
              raise ParamValueError("Page should be int and greater than or equal to 0.")
          if not isinstance(num_row, int) or num_row <= 0:
              raise ParamValueError("num_row should be int and greater than 0.")

      def read_at_page_by_id(self, category_id, page, num_row):
          """
          Query by category id in pagination.

          Args:
              category_id (int): Category id, referred to the return of `read_category_info`.
              page (int): Index of page.
              num_row (int): Number of rows in a page.

          Returns:
              List, list[dict].

          Raises:
              ParamValueError: If any parameter is invalid.
              MRMFetchDataError: If failed to fetch data by category.
              MRMUnsupportedSchemaError: If schema is invalid.
          """
          if not isinstance(category_id, int) or category_id < 0:
              raise ParamValueError("Category id should be int and greater than or equal to 0.")
          self._check_page_args(page, num_row)
          return self._segment.read_at_page_by_id(category_id, page, num_row)

      def read_at_page_by_name(self, category_name, page, num_row):
          """
          Query by category name in pagination.

          Args:
              category_name (str): String of category field's value,
                  referred to the return of `read_category_info`.
              page (int): Index of page.
              num_row (int): Number of rows in a page.

          Returns:
              The rows of the requested page, as returned by the segment.
              NOTE(review): earlier docs said "str, read at page", while
              `read_at_page_by_id` documents list[dict] -- confirm which is correct.

          Raises:
              ParamValueError: If any parameter is invalid.
          """
          if not isinstance(category_name, str):
              raise ParamValueError("Category name should be str.")
          self._check_page_args(page, num_row)
          return self._segment.read_at_page_by_name(category_name, page, num_row)