|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """
- This module reads data from MindRecord files.
- """
- import mindspore._c_mindrecord as ms
- from mindspore import log as logger
- from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError
-
- __all__ = ['ShardReader']
-
class ShardReader:
    """
    Python-side wrapper around the ShardReader class of the C++ module.

    It reads batches of data from a series of MindRecord files.
    """

    def __init__(self):
        # Every method below delegates to this underlying C++ reader object.
        self._reader = ms.ShardReader()

    def open(self, file_name, num_consumer=4, columns=None, operator=None):
        """
        Open MindRecord file(s) and get ready for reading.

        Args:
            file_name (str, list[str]): File names of MindRecord File. A list is
                forwarded unchanged with the load_dataset flag set to False; any
                other value is wrapped into a one-element list and forwarded with
                the flag set to True.
            num_consumer (int): Number of worker threads which load data in parallel. Default: 4.
            columns (list[str]): List of fields whose data would be read.
                Default: None. A falsy value is forwarded as an empty list.
            operator: Reserved parameter for operators. Default: None.
                A falsy value is forwarded as an empty list.

        Returns:
            MSRStatus, the status reported by the underlying reader (always
            SUCCESS, because any other status raises).

        Raises:
            MRMOpenError: If failed to open MindRecord File.
        """
        selected_columns = columns or []
        operators = operator or []

        # A single (non-list) name asks the underlying reader to load the dataset.
        load_dataset = not isinstance(file_name, list)
        if load_dataset:
            file_name = [file_name]

        status = self._reader.open(
            file_name, load_dataset, num_consumer, selected_columns, operators
        )
        if status != ms.MSRStatus.SUCCESS:
            logger.error("Failed to open {}.".format(file_name))
            raise MRMOpenError
        return status

    def launch(self):
        """
        Start the worker threads which load data.

        Returns:
            MSRStatus, the status reported by the underlying reader (always
            SUCCESS, because any other status raises).

        Raises:
            MRMLaunchError: If failed to launch worker threads.
        """
        # NOTE(review): the meaning of the False argument is defined by the
        # C++ binding and is not visible from this file.
        status = self._reader.launch(False)
        if status != ms.MSRStatus.SUCCESS:
            logger.error("Failed to launch worker threads.")
            raise MRMLaunchError
        return status

    def get_next(self):
        """
        Fetch the next batch of data, including blob data and raw data.

        Returns:
            list of dict.
        """
        batch = self._reader.get_next()
        return batch

    def get_blob_fields(self):
        """
        Fetch the blob fields of MindRecord.

        Returns:
            list of str.
        """
        blob_fields = self._reader.get_blob_fields()
        return blob_fields

    def get_header(self):
        """
        Fetch the header of MindRecord.

        Returns:
            pointer object refer to header.
        """
        header = self._reader.get_header()
        return header

    def finish(self):
        """
        Stop the worker threads.

        Returns:
            MSRStatus, the status reported by the underlying reader (always
            SUCCESS, because any other status raises).

        Raises:
            MRMFinishError: If failed to finish worker threads.
        """
        status = self._reader.finish()
        if status != ms.MSRStatus.SUCCESS:
            logger.error("Failed to finish worker threads.")
            raise MRMFinishError
        return status

    def close(self):
        """Close MindRecord File."""
        self._reader.close()
|