|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """
- This module reads data from MindRecord files page by page.
- """
- import mindspore._c_mindrecord as ms
- from mindspore import log as logger
- from .shardutils import populate_data, SUCCESS
- from .shardheader import ShardHeader
- from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError
-
- __all__ = ['ShardSegment']
-
class ShardSegment:
    """
    Wrapper class which represents the ShardSegment class in the C++ module.

    The class queries data from MindRecord files page by page.
    """

    def __init__(self):
        self._segment = ms.ShardSegment()
        # Both attributes are filled in by open(); rows cannot be populated before that.
        self._header = None
        self._columns = None

    def open(self, file_name, num_consumer=4, columns=None, operator=None):
        """
        Initialize the ShardSegment.

        Args:
            file_name (str, list[str]): File names of MindRecord File. A non-list value is
                wrapped into a single-element list and opened with dataset loading enabled;
                a list is passed through with dataset loading disabled.
            num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
            columns (list[str]): List of fields whose corresponding data would be read.
                A falsy value is forwarded as an empty list. Default: None.
            operator: Reserved parameter for operators. A falsy value is forwarded as an
                empty list. Default: None.

        Returns:
            MSRStatus, the status returned by the underlying segment. It is always SUCCESS
            when this method returns, because any other status raises.

        Raises:
            MRMOpenError: If failed to open MindRecord File.
        """
        self._columns = columns if columns else []
        operator = operator if operator else []

        # A list is taken as-is; anything else is treated as a single file name.
        load_dataset = not isinstance(file_name, list)
        if load_dataset:
            file_name = [file_name]

        ret = self._segment.open(file_name, load_dataset, num_consumer, self._columns, operator)
        if ret != SUCCESS:
            logger.error("Failed to open {}.".format(file_name))
            raise MRMOpenError
        self._header = ShardHeader(self._segment.get_header())
        return ret

    def get_category_fields(self):
        """
        Get candidate category fields.

        Returns:
            list[str], by which data could be grouped.

        Raises:
            MRMFetchCandidateFieldsError: If failed to get candidate category fields.
        """
        ret, fields = self._segment.get_category_fields()
        if ret != SUCCESS:
            logger.error("Failed to get candidate category fields.")
            raise MRMFetchCandidateFieldsError
        return fields

    def set_category_field(self, category_field):
        """
        Select one category field to use.

        Args:
            category_field: Category field to group data by; candidates are listed by
                get_category_fields.

        Returns:
            The value returned by the underlying segment's set_category_field call,
            passed through unchanged.
        """
        return self._segment.set_category_field(category_field)

    def read_category_info(self):
        """
        Get the group info by the current category field.

        Returns:
            str, description of group information.

        Raises:
            MRMReadCategoryInfoError: If failed to read category information.
        """
        ret, category_info = self._segment.read_category_info()
        if ret != SUCCESS:
            logger.error("Failed to read category information.")
            raise MRMReadCategoryInfoError
        return category_info

    def read_at_page_by_id(self, category_id, page, num_row):
        """
        Get the data of some page by category id.

        Args:
            category_id (int): Category id, referred to the return of read_category_info.
            page (int): Index of page.
            num_row (int): Number of rows in a page.

        Returns:
            list[dict]

        Raises:
            MRMFetchDataError: If failed to read by category id.
            MRMUnsupportedSchemaError: If schema is invalid.
        """
        ret, data = self._segment.read_at_page_by_id(category_id, page, num_row)
        if ret != SUCCESS:
            logger.error("Failed to read by category id.")
            raise MRMFetchDataError
        return self._populate_page(data)

    def read_at_page_by_name(self, category_name, page, num_row):
        """
        Get the data of some page by category name.

        Args:
            category_name (str): Category name, referred to the return of read_category_info.
            page (int): Index of page.
            num_row (int): Number of rows in a page.

        Returns:
            list[dict]

        Raises:
            MRMFetchDataError: If failed to read by category name.
            MRMUnsupportedSchemaError: If schema is invalid.
        """
        ret, data = self._segment.read_at_page_by_name(category_name, page, num_row)
        if ret != SUCCESS:
            logger.error("Failed to read by category name.")
            raise MRMFetchDataError
        return self._populate_page(data)

    def _populate_page(self, data):
        """Turn the (blob, raw) pairs of one page into rows via populate_data."""
        header = self._header
        # The segment yields (blob, raw) pairs, while populate_data takes raw first.
        return [populate_data(raw, blob, self._columns, header.blob_fields, header.schema)
                for blob, raw in data]
|