From c257cf36e2ff8b2c19ccdd39ca6aff6f66cd7e3b Mon Sep 17 00:00:00 2001 From: liyong Date: Wed, 25 Aug 2021 16:59:16 +0800 Subject: [PATCH] refactor mindrecord --- .../dataset/engine/consumers/tree_consumer.cc | 30 +- .../engine/datasetops/source/mindrecord_op.cc | 59 +- .../engine/gnn/graph_feature_parser.cc | 15 +- .../dataset/engine/gnn/graph_loader.cc | 8 +- .../minddata/mindrecord/common/shard_error.cc | 83 -- .../mindrecord/common/shard_pybind.cc | 211 +++- .../minddata/mindrecord/common/shard_utils.cc | 65 +- .../mindrecord/include/common/shard_utils.h | 21 +- .../mindrecord/include/shard_category.h | 2 +- .../mindrecord/include/shard_column.h | 39 +- .../include/shard_distributed_sample.h | 2 +- .../minddata/mindrecord/include/shard_error.h | 92 +- .../mindrecord/include/shard_header.h | 66 +- .../include/shard_index_generator.h | 56 +- .../mindrecord/include/shard_operator.h | 37 +- .../mindrecord/include/shard_pk_sample.h | 2 +- .../mindrecord/include/shard_reader.h | 142 ++- .../mindrecord/include/shard_sample.h | 6 +- .../mindrecord/include/shard_schema.h | 9 - .../mindrecord/include/shard_segment.h | 39 +- .../include/shard_sequential_sample.h | 2 +- .../mindrecord/include/shard_shuffle.h | 8 +- .../mindrecord/include/shard_statistics.h | 9 - .../mindrecord/include/shard_writer.h | 132 ++- .../mindrecord/io/shard_index_generator.cc | 533 ++++------ .../minddata/mindrecord/io/shard_reader.cc | 923 +++++++----------- .../minddata/mindrecord/io/shard_segment.cc | 303 +++--- .../minddata/mindrecord/io/shard_writer.cc | 768 ++++++--------- .../mindrecord/meta/shard_category.cc | 2 +- .../minddata/mindrecord/meta/shard_column.cc | 174 ++-- .../meta/shard_distributed_sample.cc | 16 +- .../minddata/mindrecord/meta/shard_header.cc | 482 ++++----- .../mindrecord/meta/shard_pk_sample.cc | 8 +- .../minddata/mindrecord/meta/shard_sample.cc | 27 +- .../minddata/mindrecord/meta/shard_schema.cc | 12 - .../meta/shard_sequential_sample.cc | 9 +- .../minddata/mindrecord/meta/shard_shuffle.cc | 43 +- .../mindrecord/meta/shard_statistics.cc | 18 - .../mindrecord/meta/shard_task_list.cc | 10 +- mindspore/mindrecord/shardreader.py | 2 +- mindspore/mindrecord/shardsegment.py | 30 +- tests/ut/cpp/mindrecord/ut_common.cc | 4 +- tests/ut/cpp/mindrecord/ut_shard.cc | 11 +- .../ut/cpp/mindrecord/ut_shard_header_test.cc | 37 +- .../ut/cpp/mindrecord/ut_shard_reader_test.cc | 16 +- .../cpp/mindrecord/ut_shard_segment_test.cc | 127 ++- .../ut/cpp/mindrecord/ut_shard_writer_test.cc | 78 +- .../dataset/test_minddataset_exception.py | 10 +- .../mindrecord/test_cifar100_to_mindrecord.py | 9 +- .../mindrecord/test_cifar10_to_mindrecord.py | 10 +- .../mindrecord/test_mindrecord_exception.py | 87 +- 51 files changed, 2020 insertions(+), 2864 deletions(-) delete mode 100644 mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index 01ae379c2a..99aa0eb747 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -256,9 +256,7 @@ Status SaveToDisk::Save() { auto mr_header = std::make_shared(); auto mr_writer = std::make_unique(); std::vector blob_fields; - if (mindrecord::SUCCESS != mindrecord::ShardWriter::Initialize(&mr_writer, file_names)) { - RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter, please check above `ERROR` level message."); - } + RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names)); std::unordered_map column_name_id_map; for (auto el : tree_adapter_->GetColumnNameMap()) { @@ -286,22 +284,16 @@ Status SaveToDisk::Save() { std::vector index_fields; RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields)); MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump(); - if (mindrecord::SUCCESS != - mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) { - RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader."); - } - if (mindrecord::SUCCESS != mr_writer->SetShardHeader(mr_header)) { - RETURN_STATUS_UNEXPECTED("Error: failed to set header of ShardWriter."); - } + RETURN_IF_NOT_OK( + mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)); + RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header)); first_loop = false; } // construct data if (!row.empty()) { // write data RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data)); std::shared_ptr> output_bin_data; - if (mindrecord::SUCCESS != mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)) { - RETURN_STATUS_UNEXPECTED("Error: failed to merge blob data of ShardWriter."); - } + RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)); std::map> raw_data; raw_data.insert( std::pair>(mr_schema_id, std::vector{row_raw_data})); @@ -309,18 +301,12 @@ Status SaveToDisk::Save() { if (output_bin_data != nullptr) { bin_data.emplace_back(*output_bin_data); } - if (mindrecord::SUCCESS != mr_writer->WriteRawData(raw_data, bin_data)) { - RETURN_STATUS_UNEXPECTED("Error: failed to write raw data to ShardWriter."); - } + RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data)); } } while (!row.empty()); - if (mindrecord::SUCCESS != mr_writer->Commit()) { - RETURN_STATUS_UNEXPECTED("Error: failed to commit ShardWriter."); - } - if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::Finalize(file_names)) { - RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator."); - } + RETURN_IF_NOT_OK(mr_writer->Commit()); + RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc index beb23ec80e..2cef0f1a52 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/mindrecord_op.cc @@ -23,6 +23,7 @@ #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/core/global_context.h" #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" +#include "minddata/mindrecord/include/shard_column.h" #include "minddata/dataset/engine/db_connector.h" #include "minddata/dataset/engine/execution_tree.h" #include "minddata/dataset/include/dataset/constants.h" @@ -63,10 +64,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vectorOpen(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, - num_padded_); - - CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed, " + ErrnoToMessage(rc)); + RETURN_IF_NOT_OK(shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, + operators_, num_padded_)); data_schema_ = std::make_unique(); @@ -206,7 +205,9 @@ Status MindRecordOp::GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, i fetched_row->setPath(file_path); fetched_row->setId(row_id); } - if (tupled_buffer.empty()) return Status::OK(); + if (tupled_buffer.empty()) { + return Status::OK(); + } if (task_type == mindrecord::TaskType::kCommonTask) { for (const auto &tupled_row : tupled_buffer) { std::vector columns_blob = std::get<0>(tupled_row); @@ -237,20 +238,15 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vectorGetShardColumn(); if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { - auto rc = - shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); - if (rc.first != MSRStatus::SUCCESS) { - RETURN_STATUS_UNEXPECTED("Invalid parameter, column_name: " + column_name + "does not exist in dataset."); - } - if (rc.second == mindrecord::ColumnInRaw) { - auto column_in_raw = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); - if (column_in_raw == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve raw data from padding sample."); - } - } else if (rc.second == mindrecord::ColumnInBlob) { - if (sample_bytes_.find(column_name) == sample_bytes_.end()) { - RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve blob data from padding sample."); - } + mindrecord::ColumnCategory category; + RETURN_IF_NOT_OK(shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, + &column_shape, &category)); + if (category == mindrecord::ColumnInRaw) { + RETURN_IF_NOT_OK(shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes)); + } else if (category == mindrecord::ColumnInBlob) { + CHECK_FAIL_RETURN_UNEXPECTED(sample_bytes_.find(column_name) != sample_bytes_.end(), + "Invalid data, failed to retrieve blob data from padding sample."); + std::string ss(sample_bytes_[column_name]); n_bytes = ss.size(); data_ptr = std::make_unique(n_bytes); @@ -262,12 +258,9 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector(data_ptr.get()); } } else { - auto has_column = - shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, - &column_data_type, &column_data_type_size, &column_shape); - if (has_column == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve data from mindrecord reader."); - } + RETURN_IF_NOT_OK(shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, + &n_bytes, &column_data_type, &column_data_type_size, + &column_shape)); } std::shared_ptr tensor; @@ -309,15 +302,10 @@ Status MindRecordOp::Reset() { } Status MindRecordOp::LaunchThreadsAndInitOp() { - if (tree_ == nullptr) { - RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); - } - + RETURN_UNEXPECTED_IF_NULL(tree_); RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); - if (shard_reader_->Launch(true) == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); - } + RETURN_IF_NOT_OK(shard_reader_->Launch(true)); // Launch main workers that load TensorRows by reading all images RETURN_IF_NOT_OK( tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id())); @@ -330,12 +318,7 @@ Status MindRecordOp::LaunchThreadsAndInitOp() { Status MindRecordOp::CountTotalRows(const std::vector dataset_path, bool load_dataset, const std::shared_ptr &op, int64_t *count, int64_t num_padded) { std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); - if (rc == MSRStatus::FAILED) { - RETURN_STATUS_UNEXPECTED( - "Invalid data, MindRecordOp failed to count total rows. Check whether there are corresponding .db files " - "and the value of dataset_file parameter is given correctly."); - } + RETURN_IF_NOT_OK(shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded)); return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc index f09bf8abe8..dd354636aa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc @@ -38,9 +38,8 @@ Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std:: uint64_t n_bytes = 0, col_type_size = 1; mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; std::vector column_shape; - MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, - &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); + RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, + &col_type_size, &column_shape)); if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast(n_bytes / col_type_size)})), std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), @@ -57,9 +56,8 @@ Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, con uint64_t n_bytes = 0, col_type_size = 1; mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; std::vector column_shape; - MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, - &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); + RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, + &col_type_size, &column_shape)); if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); std::shared_ptr tensor; RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); @@ -81,9 +79,8 @@ Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::v uint64_t n_bytes = 0, col_type_size = 1; mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; std::vector column_shape; - MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, - &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); + RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, + &col_type_size, &column_shape)); if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc index 1a9c3699f3..df5277bfe5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -94,10 +94,9 @@ Status GraphLoader::InitAndLoad() { TaskGroup vg; shard_reader_ = std::make_unique(); - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, - "Fail to open" + mr_path_); + RETURN_IF_NOT_OK(shard_reader_->Open({mr_path_}, true, num_workers_)); CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); + RETURN_IF_NOT_OK(shard_reader_->Launch(true)); graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); mindrecord::json schema = graph_impl_->data_schema_["schema"]; @@ -116,8 +115,7 @@ Status GraphLoader::InitAndLoad() { if (graph_impl_->server_mode_) { #if !defined(_WIN32) && !defined(_WIN64) int64_t total_blob_size = 0; - CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS, - "failed to get total blob size"); + RETURN_IF_NOT_OK(shard_reader_->GetTotalBlobSize(&total_blob_size)); graph_impl_->graph_shared_memory_ = std::make_unique(total_blob_size, mr_path_); RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); #endif diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc deleted file mode 100644 index 247c566152..0000000000 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_error.cc +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "minddata/mindrecord/include/shard_error.h" - -namespace mindspore { -namespace mindrecord { -static const std::map kErrnoToMessage = { - {FAILED, "operator failed"}, - {SUCCESS, "operator success"}, - {OPEN_FILE_FAILED, "open file failed"}, - {CLOSE_FILE_FAILED, "close file failed"}, - {WRITE_METADATA_FAILED, "write metadata failed"}, - {WRITE_RAWDATA_FAILED, "write rawdata failed"}, - {GET_SCHEMA_FAILED, "get schema failed"}, - {ILLEGAL_RAWDATA, "illegal raw data"}, - {PYTHON_TO_JSON_FAILED, "pybind: python object to json failed"}, - {DIR_CREATE_FAILED, "directory create failed"}, - {OPEN_DIR_FAILED, "open directory failed"}, - {INVALID_STATISTICS, "invalid statistics object"}, - {OPEN_DATABASE_FAILED, "open database failed"}, - {CLOSE_DATABASE_FAILED, "close database failed"}, - {DATABASE_OPERATE_FAILED, "database operate failed"}, - {BUILD_SCHEMA_FAILED, "build schema failed"}, - {DIVISOR_IS_ILLEGAL, "divisor is illegal"}, - {INVALID_FILE_PATH, "file path is invalid"}, - {SECURE_FUNC_FAILED, "secure function failed"}, - {ALLOCATE_MEM_FAILED, "allocate memory failed"}, - {ILLEGAL_FIELD_NAME, "illegal field name"}, - {ILLEGAL_FIELD_TYPE, "illegal field type"}, - {SET_METADATA_FAILED, "set metadata failed"}, - {ILLEGAL_SCHEMA_DEFINITION, "illegal schema definition"}, - {ILLEGAL_COLUMN_LIST, "illegal column list"}, - {SQL_ERROR, "sql error"}, - {ILLEGAL_SHARD_COUNT, "illegal shard count"}, - {ILLEGAL_SCHEMA_COUNT, "illegal schema count"}, - {VERSION_ERROR, "data version is not matched"}, - {ADD_SCHEMA_FAILED, "add schema failed"}, - {ILLEGAL_Header_SIZE, "illegal header size"}, - {ILLEGAL_Page_SIZE, "illegal page size"}, - {ILLEGAL_SIZE_VALUE, "illegal size value"}, - {INDEX_FIELD_ERROR, "add index fields failed"}, - {GET_CANDIDATE_CATEGORYFIELDS_FAILED, "get candidate category fields failed"}, - {GET_CATEGORY_INFO_FAILED, "get category information failed"}, - {ILLEGAL_CATEGORY_ID, "illegal category id"}, - {ILLEGAL_ROWNUMBER_OF_PAGE, "illegal row number of page"}, - {ILLEGAL_SCHEMA_ID, "illegal schema id"}, - {DESERIALIZE_SCHEMA_FAILED, "deserialize schema failed"}, - {DESERIALIZE_STATISTICS_FAILED, "deserialize statistics failed"}, - {ILLEGAL_DB_FILE, "illegal db file"}, - {OVERWRITE_DB_FILE, "overwrite db file"}, - {OVERWRITE_MINDRECORD_FILE, "overwrite mindrecord file"}, - {ILLEGAL_MINDRECORD_FILE, "illegal mindrecord file"}, - {PARSE_JSON_FAILED, "parse json failed"}, - {ILLEGAL_PARAMETERS, "illegal parameters"}, - {GET_PAGE_BY_GROUP_ID_FAILED, "get page by group id failed"}, - {GET_SYSTEM_STATE_FAILED, "get system state failed"}, - {IO_FAILED, "io operate failed"}, - {MATCH_HEADER_FAILED, "match header failed"}}; - -std::string ErrnoToMessage(MSRStatus status) { - auto iter = kErrnoToMessage.find(status); - if (iter != kErrnoToMessage.end()) { - return kErrnoToMessage.at(status); - } else { - return "invalid error no"; - } -} -} // namespace mindrecord -} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc index 72492c0857..908ec39449 100644 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc @@ -36,20 +36,42 @@ using mindspore::MsLogLevel::ERROR; namespace mindspore { namespace mindrecord { +#define THROW_IF_ERROR(s) \ + do { \ + Status rc = std::move(s); \ + if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ + } while (false) + void BindSchema(py::module *m) { (void)py::class_>(*m, "Schema", py::module_local()) - .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Schema::Build) + .def_static("build", + [](const std::string &desc, const pybind11::handle &schema) { + json schema_json = nlohmann::detail::ToJsonImpl(schema); + return Schema::Build(std::move(desc), schema_json); + }) .def("get_desc", &Schema::GetDesc) - .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) + .def("get_schema_content", + [](Schema &s) { + json schema_json = s.GetSchema(); + return nlohmann::detail::FromJsonImpl(schema_json); + }) .def("get_blob_fields", &Schema::GetBlobFields) .def("get_schema_id", &Schema::GetSchemaID); } void BindStatistics(const py::module *m) { (void)py::class_>(*m, "Statistics", py::module_local()) - .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Statistics::Build) + .def_static("build", + [](const std::string desc, const pybind11::handle &statistics) { + json statistics_json = nlohmann::detail::ToJsonImpl(statistics); + return Statistics::Build(std::move(desc), statistics_json); + }) .def("get_desc", &Statistics::GetDesc) - .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) + .def("get_statistics", + [](Statistics &s) { + json statistics_json = s.GetStatistics(); + return nlohmann::detail::FromJsonImpl(statistics_json); + }) .def("get_statistics_id", &Statistics::GetStatisticsID); } @@ -59,70 +81,179 @@ void BindShardHeader(const py::module *m) { .def("add_schema", &ShardHeader::AddSchema) .def("add_statistics", &ShardHeader::AddStatistic) .def("add_index_fields", - (MSRStatus(ShardHeader::*)(const std::vector &)) & ShardHeader::AddIndexFields) + [](ShardHeader &s, const std::vector &fields) { + THROW_IF_ERROR(s.AddIndexFields(fields)); + return SUCCESS; + }) .def("get_meta", &ShardHeader::GetSchemas) .def("get_statistics", &ShardHeader::GetStatistics) .def("get_fields", &ShardHeader::GetFields) - .def("get_schema_by_id", &ShardHeader::GetSchemaByID) - .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); + .def("get_schema_by_id", + [](ShardHeader &s, int64_t schema_id) { + std::shared_ptr schema_ptr; + THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr)); + return schema_ptr; + }) + .def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) { + std::shared_ptr statistics_ptr; + THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr)); + return statistics_ptr; + }); } void BindShardWriter(py::module *m) { (void)py::class_(*m, "ShardWriter", py::module_local()) .def(py::init<>()) - .def("open", &ShardWriter::Open) - .def("open_for_append", &ShardWriter::OpenForAppend) - .def("set_header_size", &ShardWriter::SetHeaderSize) - .def("set_page_size", &ShardWriter::SetPageSize) - .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, - vector> &, bool, bool)) & - ShardWriter::WriteRawData) - .def("commit", &ShardWriter::Commit); + .def("open", + [](ShardWriter &s, const std::vector &paths, bool append) { + THROW_IF_ERROR(s.Open(paths, append)); + return SUCCESS; + }) + .def("open_for_append", + [](ShardWriter &s, const std::string &path) { + THROW_IF_ERROR(s.OpenForAppend(path)); + return SUCCESS; + }) + .def("set_header_size", + [](ShardWriter &s, const uint64_t &header_size) { + THROW_IF_ERROR(s.SetHeaderSize(header_size)); + return SUCCESS; + }) + .def("set_page_size", + [](ShardWriter &s, const uint64_t &page_size) { + THROW_IF_ERROR(s.SetPageSize(page_size)); + return SUCCESS; + }) + .def("set_shard_header", + [](ShardWriter &s, std::shared_ptr header_data) { + THROW_IF_ERROR(s.SetShardHeader(header_data)); + return SUCCESS; + }) + .def("write_raw_data", + [](ShardWriter &s, std::map> &raw_data, vector> &blob_data, + bool sign, bool parallel_writer) { + std::map> raw_data_json; + (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), + [](const std::pair> &p) { + auto &py_raw_data = p.second; + std::vector json_raw_data; + (void)std::transform( + py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), + [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); + return std::make_pair(p.first, std::move(json_raw_data)); + }); + THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer)); + return SUCCESS; + }) + .def("commit", [](ShardWriter &s) { + THROW_IF_ERROR(s.Commit()); + return SUCCESS; + }); } void BindShardReader(const py::module *m) { (void)py::class_>(*m, "ShardReader", py::module_local()) .def(py::init<>()) - .def("open", (MSRStatus(ShardReader::*)(const std::vector &, bool, const int &, - const std::vector &, - const std::vector> &)) & - ShardReader::OpenPy) - .def("launch", &ShardReader::Launch) + .def("open", + [](ShardReader &s, const std::vector &file_paths, bool load_dataset, const int &n_consumer, + const std::vector &selected_columns, + const std::vector> &operators) { + THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators)); + return SUCCESS; + }) + .def("launch", + [](ShardReader &s) { + THROW_IF_ERROR(s.Launch(false)); + return SUCCESS; + }) .def("get_header", &ShardReader::GetShardHeader) .def("get_blob_fields", &ShardReader::GetBlobFields) - .def("get_next", (std::vector>, pybind11::object>>(ShardReader::*)()) & - ShardReader::GetNextPy) + .def("get_next", + [](ShardReader &s) { + auto data = s.GetNext(); + vector>, pybind11::object>> res; + std::transform(data.begin(), data.end(), std::back_inserter(res), + [&s](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + auto blob_data_ptr = std::make_shared>>(); + (void)s.UnCompressBlob(std::get<0>(item), &blob_data_ptr); + return std::make_tuple(*blob_data_ptr, std::move(obj)); + }); + return res; + }) .def("close", &ShardReader::Close); } void BindShardIndexGenerator(const py::module *m) { (void)py::class_(*m, "ShardIndexGenerator", py::module_local()) .def(py::init()) - .def("build", &ShardIndexGenerator::Build) - .def("write_to_db", &ShardIndexGenerator::WriteToDatabase); + .def("build", + [](ShardIndexGenerator &s) { + THROW_IF_ERROR(s.Build()); + return SUCCESS; + }) + .def("write_to_db", [](ShardIndexGenerator &s) { + THROW_IF_ERROR(s.WriteToDatabase()); + return SUCCESS; + }); } void BindShardSegment(py::module *m) { (void)py::class_(*m, "ShardSegment", py::module_local()) .def(py::init<>()) - .def("open", (MSRStatus(ShardSegment::*)(const std::vector &, bool, const int &, - const std::vector &, - const std::vector> &)) & - ShardSegment::OpenPy) + .def("open", + [](ShardSegment &s, const std::vector &file_paths, bool load_dataset, const int &n_consumer, + const std::vector &selected_columns, + const std::vector> &operators) { + THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators)); + return SUCCESS; + }) .def("get_category_fields", - (std::pair>(ShardSegment::*)()) & ShardSegment::GetCategoryFields) - .def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField) - .def("read_category_info", (std::pair(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo) - .def("read_at_page_by_id", (std::pair, pybind11::object>>>( - ShardSegment::*)(int64_t, int64_t, int64_t)) & - ShardSegment::ReadAtPageByIdPy) - .def("read_at_page_by_name", (std::pair, pybind11::object>>>( - ShardSegment::*)(std::string, int64_t, int64_t)) & - ShardSegment::ReadAtPageByNamePy) + [](ShardSegment &s) { + auto fields_ptr = std::make_shared>(); + THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr)); + return *fields_ptr; + }) + .def("set_category_field", + [](ShardSegment &s, const std::string &category_field) { + THROW_IF_ERROR(s.SetCategoryField(category_field)); + return SUCCESS; + }) + .def("read_category_info", + [](ShardSegment &s) { + std::shared_ptr category_ptr; + THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr)); + return *category_ptr; + }) + .def("read_at_page_by_id", + [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { + auto pages_load_ptr = std::make_shared(); + auto pages_ptr = std::make_shared(); + THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr)); + (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr), + [](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + return std::make_tuple(std::get<0>(item), std::move(obj)); + }); + return *pages_load_ptr; + }) + .def("read_at_page_by_name", + [](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) { + auto pages_load_ptr = std::make_shared(); + auto pages_ptr = std::make_shared(); + THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr)); + (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr), + [](const std::tuple, json> &item) { + auto &j = std::get<1>(item); + pybind11::object obj = nlohmann::detail::FromJsonImpl(j); + return std::make_tuple(std::get<0>(item), std::move(obj)); + }); + return *pages_load_ptr; + }) .def("get_header", &ShardSegment::GetShardHeader) - .def("get_blob_fields", - (std::pair>(ShardSegment::*)()) & ShardSegment::GetBlobFields); + .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); }); } void BindGlobalParams(py::module *m) { diff --git a/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc index cb27640812..f9c27cbffb 100644 --- a/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc +++ b/mindspore/ccsrc/minddata/mindrecord/common/shard_utils.cc @@ -57,26 +57,24 @@ bool ValidateFieldName(const std::string &str) { return true; } -std::pair GetFileName(const std::string &path) { +Status GetFileName(const std::string &path, std::shared_ptr *fn_ptr) { + RETURN_UNEXPECTED_IF_NULL(fn_ptr); char real_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; - return {FAILED, ""}; + RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path); } char tmp[PATH_MAX] = {0}; #if defined(_WIN32) || defined(_WIN64) if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; + RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf)); } if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; } #else if (realpath(dirname(&(buf[0])), tmp) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; + RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf); } if (realpath(common::SafeCStr(path), real_path) == nullptr) { MS_LOG(DEBUG) << "Path: " << path << "check successfully"; @@ -87,32 +85,32 @@ std::pair GetFileName(const std::string &path) { size_t i = s.rfind(sep, s.length()); if (i != std::string::npos) { if (i + 1 < s.size()) { - return {SUCCESS, s.substr(i + 1)}; + *fn_ptr = std::make_shared(s.substr(i + 1)); + return Status::OK(); } } - return {SUCCESS, s}; + *fn_ptr = std::make_shared(s); + return Status::OK(); } -std::pair GetParentDir(const std::string &path) { +Status GetParentDir(const std::string &path, std::shared_ptr *pd_ptr) { + RETURN_UNEXPECTED_IF_NULL(pd_ptr); char real_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; - return {FAILED, ""}; + RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path); } char tmp[PATH_MAX] = {0}; #if defined(_WIN32) || defined(_WIN64) if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; + RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf)); } if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; } #else if (realpath(dirname(&(buf[0])), tmp) == nullptr) { - MS_LOG(ERROR) << "Invalid file path, path: " << buf; - return {FAILED, ""}; + RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf); } if (realpath(common::SafeCStr(path), real_path) == nullptr) { MS_LOG(DEBUG) << "Path: " << path << "check successfully"; @@ -120,9 +118,11 @@ std::pair GetParentDir(const std::string &path) { #endif std::string s = real_path; if (s.rfind('/') + 1 <= s.size()) { - return {SUCCESS, s.substr(0, s.rfind('/') + 1)}; + *pd_ptr = std::make_shared(s.substr(0, s.rfind('/') + 1)); + return Status::OK(); } - return {SUCCESS, "/"}; + *pd_ptr = std::make_shared("/"); + return Status::OK(); } bool CheckIsValidUtf8(const std::string &str) { @@ -163,15 +163,16 @@ bool IsLegalFile(const std::string &path) { return false; } -std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) { +Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr *size_ptr) { + RETURN_UNEXPECTED_IF_NULL(size_ptr); #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) - return {SUCCESS, 100}; + *size_ptr = std::make_shared(100); + return Status::OK(); #else uint64_t ll_count = 0; struct statfs disk_info; if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { - MS_LOG(ERROR) << "Get disk size error"; - return {FAILED, 0}; + RETURN_STATUS_UNEXPECTED("Get disk size error."); } switch (disk_type) { @@ -187,8 +188,8 @@ std::pair GetDiskSize(const std::string &str_dir, const Dis ll_count = 0; break; } - - return {SUCCESS, ll_count}; + *size_ptr = std::make_shared(ll_count); + return Status::OK(); #endif } @@ -201,17 +202,15 @@ uint32_t GetMaxThreadNum() { return thread_num; } -std::pair> GetDatasetFiles(const std::string &path, const json &addresses) { - auto ret = GetParentDir(path); - if (SUCCESS != ret.first) { - return {FAILED, {}}; - } - std::vector abs_addresses; +Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr> *ds) { + RETURN_UNEXPECTED_IF_NULL(ds); + std::shared_ptr parent_dir; + RETURN_IF_NOT_OK(GetParentDir(path, &parent_dir)); for (const auto &p : addresses) { - std::string abs_path = ret.second + std::string(p); - abs_addresses.emplace_back(abs_path); + std::string abs_path = *parent_dir + std::string(p); + (*ds)->emplace_back(abs_path); } - return {SUCCESS, abs_addresses}; + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h index 794c6c4b16..baacd60687 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/common/shard_utils.h @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -159,13 +160,15 @@ bool ValidateFieldName(const std::string &str); /// \brief get the filename by the path /// \param s file path -/// \return -std::pair GetFileName(const std::string &s); +/// \param fn_ptr shared ptr of file name +/// \return Status +Status GetFileName(const std::string &path, std::shared_ptr *fn_ptr); /// \brief get parent dir /// \param path file path -/// \return parent path -std::pair GetParentDir(const std::string &path); +/// \param pd_ptr shared ptr of parent path +/// \return Status +Status GetParentDir(const std::string &path, std::shared_ptr *pd_ptr); bool CheckIsValidUtf8(const std::string &str); @@ -179,8 +182,9 @@ enum DiskSizeType { kTotalSize = 0, kFreeSize }; /// \brief get the free space about the disk /// \param str_dir file path /// \param disk_type: kTotalSize / kFreeSize -/// \return size in Megabytes -std::pair GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type); +/// \param size: shared ptr of size in Megabytes +/// \return Status +Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr *size); /// \brief get the max hardware concurrency /// \return max concurrency @@ -189,8 +193,9 @@ uint32_t GetMaxThreadNum(); /// \brief get absolute path of all mindrecord files /// \param path path to one fo mindrecord files /// \param addresses relative path of all mindrecord files -/// \return vector of absolute path -std::pair> GetDatasetFiles(const std::string &path, const json &addresses); +/// \param ds shared ptr of vector of absolute path +/// \return Status +Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr> *ds); } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h index beee2b928f..e6c5385b10 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_category.h @@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato bool GetReplacement() const { return replacement_; } - MSRStatus Execute(ShardTaskList &tasks) override; + Status Execute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h index 50054d01bf..f2978dbcb4 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h @@ -65,11 +65,11 @@ class __attribute__((visibility("default"))) ShardColumn { ~ShardColumn() = default; /// \brief get column value by column name - MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, - const json &columns_json, const unsigned char **data, - std::unique_ptr *data_ptr, uint64_t *const n_bytes, - ColumnDataType *column_data_type, uint64_t *column_data_type_size, - std::vector *column_shape); + Status GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, + const json &columns_json, const unsigned char **data, + std::unique_ptr *data_ptr, uint64_t *const n_bytes, + ColumnDataType *column_data_type, uint64_t *column_data_type_size, + std::vector *column_shape); /// \brief compress blob std::vector CompressBlob(const std::vector &blob, int64_t *compression_size); @@ -90,19 +90,18 @@ class __attribute__((visibility("default"))) ShardColumn { std::vector> GetColumnShape() { return column_shape_; } /// \brief get column value from blob - MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, - const unsigned char **data, std::unique_ptr *data_ptr, - uint64_t *const n_bytes); + Status GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, + const unsigned char **data, std::unique_ptr *data_ptr, + uint64_t *const n_bytes); /// \brief get column type - std::pair GetColumnTypeByName(const std::string &column_name, - ColumnDataType *column_data_type, - uint64_t *column_data_type_size, - std::vector *column_shape); + Status GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type, + uint64_t *column_data_type_size, std::vector *column_shape, + ColumnCategory *column_category); /// \brief get column value from json - MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, - std::unique_ptr *data_ptr, uint64_t *n_bytes); + Status GetColumnFromJson(const std::string &column_name, const json &columns_json, + std::unique_ptr *data_ptr, uint64_t *n_bytes); private: /// \brief initialization @@ -110,15 +109,15 @@ class __attribute__((visibility("default"))) ShardColumn { /// \brief get float value from json template - MSRStatus GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); + Status GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); /// \brief get integer value from json template - MSRStatus GetInt(std::unique_ptr *data_ptr, const json &json_column_value); + Status GetInt(std::unique_ptr *data_ptr, const json &json_column_value); /// \brief get column offset address and size from blob - MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, - uint64_t *num_bytes, uint64_t *shift_idx); + Status GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, + uint64_t *num_bytes, uint64_t *shift_idx); /// \brief check if column name is available ColumnCategory CheckColumnName(const std::string &column_name); @@ -128,8 +127,8 @@ class __attribute__((visibility("default"))) ShardColumn { /// \brief uncompress integer array column template - static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, - const std::vector &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); + static Status UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, + const std::vector &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); /// \brief convert big-endian bytes to unsigned int /// \param bytes_array bytes array diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h index 9790e00018..cc67930665 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h @@ -39,7 +39,7 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha ~ShardDistributedSample() override{}; - MSRStatus PreExecute(ShardTaskList &tasks) override; + Status PreExecute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h index 92a9d7d2d6..ffa207fbb8 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_error.h @@ -19,65 +19,55 @@ #include #include +#include "include/api/status.h" namespace mindspore { namespace mindrecord { +#define RETURN_IF_NOT_OK(_s) \ + do { \ + Status __rc = (_s); \ + if (__rc.IsError()) { \ + return __rc; \ + } \ + } while (false) + +#define RELEASE_AND_RETURN_IF_NOT_OK(_s, _db, _in) \ + do { \ + Status __rc = (_s); \ + if (__rc.IsError()) { \ + if ((_db) != nullptr) { \ + sqlite3_close(_db); \ + } \ + (_in).close(); \ + return __rc; \ + } \ + } while (false) + +#define CHECK_FAIL_RETURN_UNEXPECTED(_condition, _e) \ + do { \ + if (!(_condition)) { \ + return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \ + } \ + } while (false) + +#define RETURN_UNEXPECTED_IF_NULL(_ptr) \ + do { \ + if ((_ptr) == nullptr) { \ + std::string err_msg = "The pointer[" + std::string(#_ptr) + "] is null."; \ + RETURN_STATUS_UNEXPECTED(err_msg); \ + } \ + } while (false) + +#define RETURN_STATUS_UNEXPECTED(_e) \ + do { \ + return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \ + } while (false) + enum MSRStatus { SUCCESS = 0, FAILED = 1, - OPEN_FILE_FAILED, - CLOSE_FILE_FAILED, - WRITE_METADATA_FAILED, - WRITE_RAWDATA_FAILED, - GET_SCHEMA_FAILED, - ILLEGAL_RAWDATA, - PYTHON_TO_JSON_FAILED, - DIR_CREATE_FAILED, - OPEN_DIR_FAILED, - INVALID_STATISTICS, - OPEN_DATABASE_FAILED, - CLOSE_DATABASE_FAILED, - DATABASE_OPERATE_FAILED, - BUILD_SCHEMA_FAILED, - DIVISOR_IS_ILLEGAL, - INVALID_FILE_PATH, - SECURE_FUNC_FAILED, - ALLOCATE_MEM_FAILED, - ILLEGAL_FIELD_NAME, - ILLEGAL_FIELD_TYPE, - SET_METADATA_FAILED, - ILLEGAL_SCHEMA_DEFINITION, - ILLEGAL_COLUMN_LIST, - SQL_ERROR, - ILLEGAL_SHARD_COUNT, - ILLEGAL_SCHEMA_COUNT, - VERSION_ERROR, - ADD_SCHEMA_FAILED, - ILLEGAL_Header_SIZE, - ILLEGAL_Page_SIZE, - ILLEGAL_SIZE_VALUE, - INDEX_FIELD_ERROR, - GET_CANDIDATE_CATEGORYFIELDS_FAILED, - GET_CATEGORY_INFO_FAILED, - ILLEGAL_CATEGORY_ID, - ILLEGAL_ROWNUMBER_OF_PAGE, - ILLEGAL_SCHEMA_ID, - DESERIALIZE_SCHEMA_FAILED, - DESERIALIZE_STATISTICS_FAILED, - ILLEGAL_DB_FILE, - OVERWRITE_DB_FILE, - OVERWRITE_MINDRECORD_FILE, - ILLEGAL_MINDRECORD_FILE, - PARSE_JSON_FAILED, - ILLEGAL_PARAMETERS, - GET_PAGE_BY_GROUP_ID_FAILED, - GET_SYSTEM_STATE_FAILED, - IO_FAILED, - MATCH_HEADER_FAILED }; -// convert error no to string message -std::string __attribute__((visibility("default"))) ErrnoToMessage(MSRStatus status); } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h index e2bff12c46..9e5c34d0de 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h @@ -37,9 +37,9 @@ class __attribute__((visibility("default"))) ShardHeader { ~ShardHeader() = default; - MSRStatus BuildDataset(const std::vector &file_paths, bool load_dataset = true); + Status BuildDataset(const std::vector &file_paths, bool load_dataset = true); - static std::pair BuildSingleHeader(const std::string &file_path); + static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr *header_ptr); /// \brief add the schema and save it /// \param[in] schema the schema needs to be added /// \return the last schema's id @@ -53,9 +53,9 @@ class __attribute__((visibility("default"))) ShardHeader { /// \brief create index and add fields which from schema for each schema /// \param[in] fields the index fields needs to be added /// \return SUCCESS if add successfully, FAILED if not - MSRStatus AddIndexFields(std::vector> fields); + Status AddIndexFields(std::vector> fields); - MSRStatus AddIndexFields(const std::vector &fields); + Status AddIndexFields(const std::vector &fields); /// \brief get the schema /// \return the schema @@ -79,9 +79,10 @@ class __attribute__((visibility("default"))) ShardHeader { std::shared_ptr GetIndex(); /// \brief get the schema by schemaid - /// \param[in] schemaId the id of schema needs to be got - /// \return the schema obtained by schemaId - std::pair, MSRStatus> GetSchemaByID(int64_t schema_id); + /// \param[in] schema_id the id of schema needs to be got + /// \param[in] schema_ptr the schema obtained by schemaId + /// \return Status + Status GetSchemaByID(int64_t schema_id, std::shared_ptr *schema_ptr); /// \brief get the filepath to shard by shardID /// \param[in] shardID the id of shard which filepath needs to be obtained @@ -89,25 +90,26 @@ class __attribute__((visibility("default"))) ShardHeader { std::string GetShardAddressByID(int64_t shard_id); /// \brief get the statistic by statistic id - /// \param[in] statisticId the id of statistic needs to be get - /// \return the statistics obtained by statistic id - std::pair, MSRStatus> GetStatisticByID(int64_t statistic_id); + /// \param[in] statistic_id the id of statistic needs to be get + /// \param[in] statistics_ptr the statistics obtained by statistic id + /// \return Status + Status GetStatisticByID(int64_t statistic_id, std::shared_ptr *statistics_ptr); - MSRStatus InitByFiles(const std::vector &file_paths); + Status InitByFiles(const std::vector &file_paths); void SetIndex(Index index) { index_ = std::make_shared(index); } - std::pair, MSRStatus> GetPage(const int &shard_id, const int &page_id); + Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr *page_ptr); - MSRStatus SetPage(const std::shared_ptr &new_page); + Status SetPage(const std::shared_ptr &new_page); - MSRStatus AddPage(const std::shared_ptr &new_page); + Status AddPage(const std::shared_ptr &new_page); int64_t GetLastPageId(const int &shard_id); int GetLastPageIdByType(const int &shard_id, const std::string &page_type); - const std::pair> GetPageByGroupId(const int &group_id, const int &shard_id); + Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr *page_ptr); std::vector GetShardAddresses() const { return shard_addresses_; } @@ -129,43 +131,41 @@ class __attribute__((visibility("default"))) ShardHeader { std::vector SerializeHeader(); - MSRStatus PagesToFile(const std::string dump_file_name); + Status PagesToFile(const std::string dump_file_name); - MSRStatus FileToPages(const std::string dump_file_name); + Status FileToPages(const std::string dump_file_name); - static MSRStatus Initialize(const std::shared_ptr *header_ptr, const json &schema, - const std::vector &index_fields, std::vector &blob_fields, - uint64_t &schema_id); + static Status Initialize(const std::shared_ptr *header_ptr, const json &schema, + const std::vector &index_fields, std::vector &blob_fields, + uint64_t &schema_id); private: - MSRStatus InitializeHeader(const std::vector &headers, bool load_dataset); + Status InitializeHeader(const std::vector &headers, bool load_dataset); /// \brief get the headers from all the shard data /// \param[in] the shard data real path /// \param[in] the headers which read from the shard data /// \return SUCCESS/FAILED - MSRStatus GetHeaders(const vector &real_addresses, std::vector &headers); + Status GetHeaders(const vector &real_addresses, std::vector &headers); - MSRStatus ValidateField(const std::vector &field_name, json schema, const uint64_t &schema_id); + Status ValidateField(const std::vector &field_name, json schema, const uint64_t &schema_id); /// \brief check the binary file status - static MSRStatus CheckFileStatus(const std::string &path); + static Status CheckFileStatus(const std::string &path); - static std::pair ValidateHeader(const std::string &path); - - void ParseHeader(const json &header); + static Status ValidateHeader(const std::string &path, std::shared_ptr *header_ptr); void GetHeadersOneTask(int start, int end, std::vector &headers, const vector &realAddresses); - MSRStatus ParseIndexFields(const json &index_fields); + Status ParseIndexFields(const json &index_fields); - MSRStatus CheckIndexField(const std::string &field, const json &schema); + Status CheckIndexField(const std::string &field, const json &schema); - MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset); + Status ParsePage(const json &page, int shard_index, bool load_dataset); - MSRStatus ParseStatistics(const json &statistics); + Status ParseStatistics(const json &statistics); - MSRStatus ParseSchema(const json &schema); + Status ParseSchema(const json &schema); void ParseShardAddress(const json &address); @@ -181,7 +181,7 @@ class __attribute__((visibility("default"))) ShardHeader { std::shared_ptr InitIndexPtr(); - MSRStatus GetAllSchemaID(std::set &bucket_count); + Status GetAllSchemaID(std::set &bucket_count); uint32_t shard_count_; uint64_t header_size_; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h index 474d6bb6d4..2942aa2fb8 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_index_generator.h @@ -30,23 +30,24 @@ namespace mindspore { namespace mindrecord { -using INDEX_FIELDS = std::pair>>; -using ROW_DATA = std::pair>>>; +using INDEX_FIELDS = std::vector>; +using ROW_DATA = std::vector>>; class __attribute__((visibility("default"))) ShardIndexGenerator { public: explicit ShardIndexGenerator(const std::string &file_path, bool append = false); - MSRStatus Build(); + Status Build(); - static std::pair GenerateFieldName(const std::pair &field); + static Status GenerateFieldName(const std::pair &field, std::shared_ptr *fn_ptr); ~ShardIndexGenerator() {} /// \brief fetch value in json by field name /// \param[in] field /// \param[in] input - /// \return pair - std::pair GetValueByField(const string &field, json input); + /// \param[in] value + /// \return Status + Status GetValueByField(const string &field, json input, std::shared_ptr *value); /// \brief fetch field type in schema n by field path /// \param[in] field_path @@ -55,55 +56,54 @@ class __attribute__((visibility("default"))) ShardIndexGenerator { static std::string TakeFieldType(const std::string &field_path, json schema); /// \brief create databases for indexes - MSRStatus WriteToDatabase(); + Status WriteToDatabase(); - static MSRStatus Finalize(const std::vector file_names); + static Status Finalize(const std::vector file_names); private: static int Callback(void *not_used, int argc, char **argv, char **az_col_name); - static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); + static Status ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); static std::string ConvertJsonToSQL(const std::string &json); - std::pair CreateDatabase(int shard_no); + Status CreateDatabase(int shard_no, sqlite3 **db); - std::pair> GetSchemaDetails(const std::vector &schema_lens, std::fstream &in); + Status GetSchemaDetails(const std::vector &schema_lens, std::fstream &in, + std::shared_ptr> *detail_ptr); - static std::pair GenerateRawSQL(const std::vector> &fields); + static Status GenerateRawSQL(const std::vector> &fields, + std::shared_ptr *sql_ptr); - std::pair CheckDatabase(const std::string &shard_address); + Status CheckDatabase(const std::string &shard_address, sqlite3 **db); /// /// \param shard_no /// \param blob_id_to_page_id /// \param raw_page_id /// \param in - /// \return field name, db type, field value - ROW_DATA GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, int raw_page_id, - std::fstream &in); + /// \return Status + Status GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, int raw_page_id, std::fstream &in, + std::shared_ptr *row_data_ptr); /// /// \param db /// \param sql /// \param data /// \return - MSRStatus BindParameterExecuteSQL( - sqlite3 *db, const std::string &sql, - const std::vector>> &data); + Status BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data); - INDEX_FIELDS GenerateIndexFields(const std::vector &schema_detail); + Status GenerateIndexFields(const std::vector &schema_detail, std::shared_ptr *index_fields_ptr); - MSRStatus ExecuteTransaction(const int &shard_no, std::pair &db, - const std::vector &raw_page_ids, const std::map &blob_id_to_page_id); + Status ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector &raw_page_ids, + const std::map &blob_id_to_page_id); - MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); + Status CreateShardNameTable(sqlite3 *db, const std::string &shard_name); - MSRStatus AddBlobPageInfo(std::vector> &row_data, - const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, - std::fstream &in); + Status AddBlobPageInfo(std::vector> &row_data, + const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, std::fstream &in); - void AddIndexFieldByRawData(const std::vector &schema_detail, - std::vector> &row_data); + Status AddIndexFieldByRawData(const std::vector &schema_detail, + std::vector> &row_data); void DatabaseWriter(); // worker thread diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h index f206ef77ca..fd6219f871 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_operator.h @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ @@ -28,33 +27,29 @@ class __attribute__((visibility("default"))) ShardOperator { public: virtual ~ShardOperator() = default; - MSRStatus operator()(ShardTaskList &tasks) { - if (SUCCESS != this->PreExecute(tasks)) { - return FAILED; - } - if (SUCCESS != this->Execute(tasks)) { - return FAILED; - } - if (SUCCESS != this->SufExecute(tasks)) { - return FAILED; - } - return SUCCESS; + Status operator()(ShardTaskList &tasks) { + RETURN_IF_NOT_OK(this->PreExecute(tasks)); + RETURN_IF_NOT_OK(this->Execute(tasks)); + RETURN_IF_NOT_OK(this->SufExecute(tasks)); + return Status::OK(); } virtual bool HasChildOp() { return child_op_ != nullptr; } - virtual MSRStatus SetChildOp(std::shared_ptr child_op) { - if (child_op != nullptr) child_op_ = child_op; - return SUCCESS; + virtual Status SetChildOp(std::shared_ptr child_op) { + if (child_op != nullptr) { + child_op_ = child_op; + } + return Status::OK(); } virtual std::shared_ptr GetChildOp() { return child_op_; } - virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; } + virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); } - virtual MSRStatus Execute(ShardTaskList &tasks) = 0; + virtual Status Execute(ShardTaskList &tasks) = 0; - virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; } + virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); } virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } @@ -72,9 +67,9 @@ class __attribute__((visibility("default"))) ShardOperator { std::shared_ptr child_op_ = nullptr; // indicate shard_id : inc_count - // 0 : 15 - shard0 has 15 samples - // 1 : 41 - shard1 has 26 samples - // 2 : 58 - shard2 has 17 samples + // // 0 : 15 - shard0 has 15 samples + // // 1 : 41 - shard1 has 26 samples + // // 2 : 58 - shard2 has 17 samples std::vector shard_sample_count_; dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h index fecdb97905..090d36a1f0 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_pk_sample.h @@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor ~ShardPkSample() override{}; - MSRStatus SufExecute(ShardTaskList &tasks) override; + Status SufExecute(ShardTaskList &tasks) override; int64_t GetNumSamples() const { return num_samples_; } diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h index f30d5c8ebb..28adeca8c8 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -59,12 +59,9 @@ namespace mindspore { namespace mindrecord { -using ROW_GROUPS = - std::tuple>>, std::vector>>; -using ROW_GROUP_BRIEF = - std::tuple>, std::vector>; -using TASK_RETURN_CONTENT = - std::pair, json>>>>; +using ROW_GROUPS = std::pair>>, std::vector>>; +using ROW_GROUP_BRIEF = std::tuple>, std::vector>; +using TASK_CONTENT = std::pair, json>>>; const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode class API_PUBLIC ShardReader { @@ -82,21 +79,10 @@ class API_PUBLIC ShardReader { /// \param[in] num_padded the number of padded samples /// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization /// \return MSRStatus the status of MSRStatus - MSRStatus Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, - const std::vector &selected_columns = {}, - const std::vector> &operators = {}, const int num_padded = 0, - bool lazy_load = false); - - /// \brief open files and initialize reader, python API - /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list - /// \param[in] load_dataset load dataset from single file or not - /// \param[in] n_consumer number of threads when reading - /// \param[in] selected_columns column list to be populated - /// \param[in] operators operators applied to data, operator type is shuffle, sample or category - /// \return MSRStatus the status of MSRStatus - MSRStatus OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer = 4, - const std::vector &selected_columns = {}, - const std::vector> &operators = {}); + Status Open(const std::vector &file_paths, bool load_dataset, int n_consumer = 4, + const std::vector &selected_columns = {}, + const std::vector> &operators = {}, const int num_padded = 0, + bool lazy_load = false); /// \brief close reader /// \return null @@ -104,16 +90,16 @@ class API_PUBLIC ShardReader { /// \brief read the file, get schema meta,statistics and index, single-thread mode /// \return MSRStatus the status of MSRStatus - MSRStatus Open(); + Status Open(); /// \brief read the file, get schema meta,statistics and index, multiple-thread mode /// \return MSRStatus the status of MSRStatus - MSRStatus Open(int n_consumer); + Status Open(int n_consumer); /// \brief launch threads to get batches /// \param[in] is_simple_reader trigger threads if false; do nothing if true /// \return MSRStatus the status of MSRStatus - MSRStatus Launch(bool is_simple_reader = false); + Status Launch(bool is_simple_reader = false); /// \brief aim to get the meta data /// \return the metadata @@ -133,8 +119,8 @@ class API_PUBLIC ShardReader { /// \param[in] op smart pointer refer to ShardCategory or ShardSample object /// \param[out] count # of rows /// \return MSRStatus the status of MSRStatus - MSRStatus CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &op, int64_t *count, const int num_padded); + Status CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &op, int64_t *count, const int num_padded); /// \brief shuffle task with incremental seed /// \return void @@ -162,8 +148,8 @@ class API_PUBLIC ShardReader { /// 3. Offset address of row group in file /// 4. The list of image offset in page [startOffset, endOffset) /// 5. The list of columns data - ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id, - const std::vector &columns = std::vector()); + Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns, + std::shared_ptr *row_group_brief_ptr); /// \brief Read 1 row group data, excluding images, following an index field criteria /// \param[in] groupID row group ID @@ -176,8 +162,9 @@ class API_PUBLIC ShardReader { /// 3. Offset address of row group in file /// 4. The list of image offset in page [startOffset, endOffset) /// 5. The list of columns data - ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria, - const std::vector &columns = std::vector()); + Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair &criteria, + const std::vector &columns, + std::shared_ptr *row_group_brief_ptr); /// \brief return a batch, given that one is ready /// \return a batch of images and image data @@ -185,13 +172,7 @@ class API_PUBLIC ShardReader { /// \brief return a row by id /// \return a batch of images and image data - std::pair, json>>> GetNextById(const int64_t &task_id, - const int32_t &consumer_id); - - /// \brief return a batch, given that one is ready, python API - /// \return a batch of images and image data - std::vector>, pybind11::object>> GetNextPy(); - + TASK_CONTENT GetNextById(const int64_t &task_id, const int32_t &consumer_id); /// \brief get blob filed list /// \return blob field list std::pair> GetBlobFields(); @@ -205,83 +186,86 @@ class API_PUBLIC ShardReader { void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } /// \brief get all classes - MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr> category_ptr); - - /// \brief get the size of blob data - MSRStatus GetTotalBlobSize(int64_t *total_blob_size); + Status GetAllClasses(const std::string &category_field, std::shared_ptr> category_ptr); /// \brief get a read-only ptr to the sampled ids for this epoch const std::vector *GetSampleIds(); + /// \brief get the size of blob data + Status GetTotalBlobSize(int64_t *total_blob_size); + + /// \brief extract uncompressed data based on column list + Status UnCompressBlob(const std::vector &raw_blob_data, + std::shared_ptr>> *blob_data_ptr); + protected: /// \brief sqlite call back function static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); private: /// \brief wrap up labels to json format - MSRStatus ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, - std::shared_ptr>>> offset_ptr, - int shard_id, const std::vector &columns, - std::shared_ptr>> col_val_ptr); + Status ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, + std::shared_ptr>>> offset_ptr, int shard_id, + const std::vector &columns, + std::shared_ptr>> col_val_ptr); /// \brief read all rows for specified columns - ROW_GROUPS ReadAllRowGroup(const std::vector &columns); + Status ReadAllRowGroup(const std::vector &columns, std::shared_ptr *row_group_ptr); /// \brief read row meta by shard_id and sample_id - ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector &columns, const uint32_t &shard_id, - const uint32_t &sample_id); + Status ReadRowGroupByShardIDAndSampleID(const std::vector &columns, const uint32_t &shard_id, + const uint32_t &sample_id, std::shared_ptr *row_group_ptr); /// \brief read all rows in one shard - MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, - std::shared_ptr>>> offset_ptr, - std::shared_ptr>> col_val_ptr); + Status ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, + std::shared_ptr>>> offset_ptr, + std::shared_ptr>> col_val_ptr); /// \brief initialize reader - MSRStatus Init(const std::vector &file_paths, bool load_dataset); + Status Init(const std::vector &file_paths, bool load_dataset); /// \brief validate column list - MSRStatus CheckColumnList(const std::vector &selected_columns); + Status CheckColumnList(const std::vector &selected_columns); /// \brief populate one row by task list in row-reader mode - MSRStatus ConsumerByRow(int consumer_id); + void ConsumerByRow(int consumer_id); /// \brief get offset address of images within page std::vector> GetImageOffset(int group_id, int shard_id, const std::pair &criteria = {"", ""}); /// \brief get page id by category - std::pair> GetPagesByCategory(int shard_id, - const std::pair &criteria); + Status GetPagesByCategory(int shard_id, const std::pair &criteria, + std::shared_ptr> *pages_ptr); /// \brief execute sqlite query with prepare statement - MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, - std::shared_ptr>> labels_ptr); + Status QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, + std::shared_ptr>> labels_ptr); /// \brief verify the validity of dataset - MSRStatus VerifyDataset(sqlite3 **db, const string &file); + Status VerifyDataset(sqlite3 **db, const string &file); /// \brief get column values - std::pair> GetLabels(int group_id, int shard_id, const std::vector &columns, - const std::pair &criteria = {"", ""}); + Status GetLabels(int page_id, int shard_id, const std::vector &columns, + const std::pair &criteria, std::shared_ptr> *labels_ptr); /// \brief get column values from raw data page - std::pair> GetLabelsFromPage(int group_id, int shard_id, - const std::vector &columns, - const std::pair &criteria = {"", - ""}); + Status GetLabelsFromPage(int page_id, int shard_id, const std::vector &columns, + const std::pair &criteria, + std::shared_ptr> *labels_ptr); /// \brief create category-applied task list - MSRStatus CreateTasksByCategory(const std::shared_ptr &op); + Status CreateTasksByCategory(const std::shared_ptr &op); /// \brief create task list in row-reader mode - MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, - const std::vector> &operators); + Status CreateTasksByRow(const std::vector> &row_group_summary, + const std::vector> &operators); /// \brief create task list in row-reader mode and lazy mode - MSRStatus CreateLazyTasksByRow(const std::vector> &row_group_summary, - const std::vector> &operators); + Status CreateLazyTasksByRow(const std::vector> &row_group_summary, + const std::vector> &operators); /// \brief crate task list - MSRStatus CreateTasks(const std::vector> &row_group_summary, - const std::vector> &operators); + Status CreateTasks(const std::vector> &row_group_summary, + const std::vector> &operators); /// \brief check if all specified columns are in index table void CheckIfColumnInIndex(const std::vector &columns); @@ -290,11 +274,12 @@ class API_PUBLIC ShardReader { void FileStreamsOperator(); /// \brief read one row by one task - TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); + Status ConsumerOneTask(int task_id, uint32_t consumer_id, std::shared_ptr *task_content_pt); /// \brief get labels from binary file - std::pair> GetLabelsFromBinaryFile( - int shard_id, const std::vector &columns, const std::vector> &label_offsets); + Status GetLabelsFromBinaryFile(int shard_id, const std::vector &columns, + const std::vector> &label_offsets, + std::shared_ptr> *labels_ptr); /// \brief get classes in one shard void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, @@ -304,11 +289,8 @@ class API_PUBLIC ShardReader { int64_t GetNumClasses(const std::string &category_field); /// \brief get meta of header - std::pair> GetMeta(const std::string &file_path, - std::shared_ptr meta_data_ptr); - - /// \brief extract uncompressed data based on column list - std::pair>> UnCompressBlob(const std::vector &raw_blob_data); + Status GetMeta(const std::string &file_path, std::shared_ptr meta_data_ptr, + std::shared_ptr> *addresses_ptr); protected: uint64_t header_size_; // header size diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h index 6f469625df..a6cdc08e24 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h @@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator ~ShardSample() override{}; - MSRStatus Execute(ShardTaskList &tasks) override; + Status Execute(ShardTaskList &tasks) override; - MSRStatus UpdateTasks(ShardTaskList &tasks, int taking); + Status UpdateTasks(ShardTaskList &tasks, int taking); - MSRStatus SufExecute(ShardTaskList &tasks) override; + Status SufExecute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h index dae40e14ef..9f632a375b 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_schema.h @@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Schema { /// \param[in] schema the schema's json static std::shared_ptr Build(std::string desc, const json &schema); - /// \brief obtain the json schema and its description for python - /// \param[in] desc the description of the schema - /// \param[in] schema the schema's json - static std::shared_ptr Build(std::string desc, pybind11::handle schema); - /// \brief compare two schema to judge if they are equal /// \param b another schema to be judged /// \return true if they are equal,false if not @@ -57,10 +52,6 @@ class __attribute__((visibility("default"))) Schema { /// \return the json format of the schema and its description json GetSchema() const; - /// \brief get the schema and its description for python method - /// \return the python object of the schema and its description - pybind11::object GetSchemaForPython() const; - /// set the schema id /// \param[in] id the id need to be set void SetSchemaID(int64_t id); diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h index 53aed517c5..4fb1b30af6 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_segment.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ +#include #include #include #include @@ -25,6 +26,10 @@ namespace mindspore { namespace mindrecord { +using CATEGORY_INFO = std::vector>; +using PAGES = std::vector, json>>; +using PAGES_LOAD = std::vector, pybind11::object>>; + class __attribute__((visibility("default"))) ShardSegment : public ShardReader { public: ShardSegment(); @@ -33,12 +38,12 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader { /// \brief Get candidate category fields /// \return a list of fields names which are the candidates of category - std::pair> GetCategoryFields(); + Status GetCategoryFields(std::shared_ptr> *fields_ptr); /// \brief Set category field /// \param[in] category_field category name /// \return true if category name is existed - MSRStatus SetCategoryField(std::string category_field); + Status SetCategoryField(std::string category_field); /// \brief Thread-safe implementation of ReadCategoryInfo /// \return statistics data in json format with 2 field: "key" and "categories". @@ -50,47 +55,41 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader { /// { "key": "label", /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, /// { "count": 3, "id": 1, "name": "finance", } ] } - std::pair ReadCategoryInfo(); + Status ReadCategoryInfo(std::shared_ptr *category_ptr); /// \brief Thread-safe implementation of ReadAtPageById /// \param[in] category_id category ID /// \param[in] page_no page number /// \param[in] n_rows_of_page rows number in one page /// \return images array, image is a vector of uint8_t - std::pair>> ReadAtPageById(int64_t category_id, int64_t page_no, - int64_t n_rows_of_page); + Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr>> *page_ptr); /// \brief Thread-safe implementation of ReadAtPageByName /// \param[in] category_name category Name /// \param[in] page_no page number /// \param[in] n_rows_of_page rows number in one page /// \return images array, image is a vector of uint8_t - std::pair>> ReadAtPageByName(std::string category_name, int64_t page_no, - int64_t n_rows_of_page); - - std::pair, json>>> ReadAllAtPageById(int64_t category_id, - int64_t page_no, - int64_t n_rows_of_page); - - std::pair, json>>> ReadAllAtPageByName( - std::string category_name, int64_t page_no, int64_t n_rows_of_page); + Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr>> *pages_ptr); - std::pair, pybind11::object>>> ReadAtPageByIdPy( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page); + Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr *pages_ptr); - std::pair, pybind11::object>>> ReadAtPageByNamePy( - std::string category_name, int64_t page_no, int64_t n_rows_of_page); + Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr *pages_ptr); std::pair> GetBlobFields(); private: - std::pair>> WrapCategoryInfo(); + Status WrapCategoryInfo(std::shared_ptr *category_info_ptr); std::string ToJsonForCategory(const std::vector> &tri_vec); std::string CleanUp(std::string fieldName); - std::pair> PackImages(int group_id, int shard_id, std::vector offset); + Status PackImages(int group_id, int shard_id, std::vector offset, + std::shared_ptr> *images_ptr); std::vector candidate_category_fields_; std::string current_category_field_; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h index 6b17497d53..9f86621f98 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_sequential_sample.h @@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar ~ShardSequentialSample() override{}; - MSRStatus Execute(ShardTaskList &tasks) override; + Status Execute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h index 1eda526a77..b14d975e58 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_shuffle.h @@ -31,19 +31,19 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator ~ShardShuffle() override{}; - MSRStatus Execute(ShardTaskList &tasks) override; + Status Execute(ShardTaskList &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; private: // Private helper function - MSRStatus CategoryShuffle(ShardTaskList &tasks); + Status CategoryShuffle(ShardTaskList &tasks); // Keep the file sequence the same but shuffle the data within each file - MSRStatus ShuffleInfile(ShardTaskList &tasks); + Status ShuffleInfile(ShardTaskList &tasks); // Shuffle the file sequence but keep the order of data within each file - MSRStatus ShuffleFiles(ShardTaskList &tasks); + Status ShuffleFiles(ShardTaskList &tasks); uint32_t shuffle_seed_; int64_t no_of_samples_; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h index 45dfa277f6..910ff9cff1 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_statistics.h @@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Statistics { /// \param[in] statistics the statistic needs to be saved static std::shared_ptr Build(std::string desc, const json &statistics); - /// \brief save the statistic from python and its description - /// \param[in] desc the statistic's description - /// \param[in] statistics the statistic needs to be saved - static std::shared_ptr Build(std::string desc, pybind11::handle statistics); - ~Statistics() = default; /// \brief compare two statistics to judge if they are equal @@ -59,10 +54,6 @@ class __attribute__((visibility("default"))) Statistics { /// \return json format of the statistic json GetStatistics() const; - /// \brief get the statistic for python - /// \return the python object of statistics - pybind11::object GetStatisticsForPython() const; - /// \brief decode the bson statistics to json /// \param[in] encodedStatistics the bson type of statistics /// \return json type of statistic diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h index d014536ff3..34af97505b 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h @@ -55,69 +55,60 @@ class __attribute__((visibility("default"))) ShardWriter { /// \brief Open file at the beginning /// \param[in] paths the file names list /// \param[in] append new data at the end of file if true, otherwise overwrite file - /// \return MSRStatus the status of MSRStatus - MSRStatus Open(const std::vector &paths, bool append = false); + /// \return Status + Status Open(const std::vector &paths, bool append = false); /// \brief Open file at the ending /// \param[in] paths the file names list /// \return MSRStatus the status of MSRStatus - MSRStatus OpenForAppend(const std::string &path); + Status OpenForAppend(const std::string &path); /// \brief Write header to disk /// \return MSRStatus the status of MSRStatus - MSRStatus Commit(); + Status Commit(); /// \brief Set file size /// \param[in] header_size the size of header, only (1< header_data); + Status SetShardHeader(std::shared_ptr header_data); /// \brief write raw data by group size /// \param[in] raw_data the vector of raw json data, vector format /// \param[in] blob_data the vector of image data /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); - - /// \brief write raw data by group size for call from python - /// \param[in] raw_data the vector of raw json data, python-handle format - /// \param[in] blob_data the vector of image data - /// \param[in] sign validate data or not - /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); + Status WriteRawData(std::map> &raw_data, vector> &blob_data, + bool sign = true, bool parallel_writer = false); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format /// \param[in] blob_data the vector of blob json data, python-handle format /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully - MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true, - bool parallel_writer = false); + Status WriteRawData(std::map> &raw_data, + std::map> &blob_data, bool sign = true, + bool parallel_writer = false); - MSRStatus MergeBlobData(const std::vector &blob_fields, - const std::map>> &row_bin_data, - std::shared_ptr> *output); + Status MergeBlobData(const std::vector &blob_fields, + const std::map>> &row_bin_data, + std::shared_ptr> *output); - static MSRStatus Initialize(const std::unique_ptr *writer_ptr, - const std::vector &file_names); + static Status Initialize(const std::unique_ptr *writer_ptr, const std::vector &file_names); private: /// \brief write shard header data to disk - MSRStatus WriteShardHeader(); + Status WriteShardHeader(); /// \brief erase error data void DeleteErrorData(std::map> &raw_data, std::vector> &blob_data); @@ -130,108 +121,107 @@ class __attribute__((visibility("default"))) ShardWriter { std::map &err_raw_data); /// \brief write shard header data to disk - std::tuple ValidateRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign); + Status ValidateRawData(std::map> &raw_data, std::vector> &blob_data, + bool sign, std::shared_ptr> *count_ptr); /// \brief fill data array in multiple thread run void FillArray(int start, int end, std::map> &raw_data, std::vector> &bin_data); /// \brief serialized raw data - MSRStatus SerializeRawData(std::map> &raw_data, - std::vector> &bin_data, uint32_t row_count); + Status SerializeRawData(std::map> &raw_data, std::vector> &bin_data, + uint32_t row_count); /// \brief write all data parallel - MSRStatus ParallelWriteData(const std::vector> &blob_data, - const std::vector> &bin_raw_data); + Status ParallelWriteData(const std::vector> &blob_data, + const std::vector> &bin_raw_data); /// \brief write data shard by shard - MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector> &blob_data, - const std::vector> &bin_raw_data); + Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector> &blob_data, + const std::vector> &bin_raw_data); /// \brief break image data up into multiple row groups - MSRStatus CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, - std::vector> &rows_in_group, const std::shared_ptr &last_raw_page, - const std::shared_ptr &last_blob_page); + Status CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, + std::vector> &rows_in_group, const std::shared_ptr &last_raw_page, + const std::shared_ptr &last_blob_page); /// \brief append partial blob data to previous page - MSRStatus AppendBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page); - - /// \brief write new blob data page to disk - MSRStatus NewBlobPage(const int &shard_id, const std::vector> &blob_data, + Status AppendBlobPage(const int &shard_id, const std::vector> &blob_data, const std::vector> &rows_in_group, const std::shared_ptr &last_blob_page); + /// \brief write new blob data page to disk + Status NewBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page); + /// \brief shift last row group to next raw page for new appending - MSRStatus ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page); + Status ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page); /// \brief write raw data page to disk - MSRStatus WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page, const std::vector> &bin_raw_data); + Status WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page, const std::vector> &bin_raw_data); /// \brief generate empty raw data page - void EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page); + Status EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page); /// \brief append a row group at the end of raw page - MSRStatus AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, - const int &chunk_id, int &last_row_groupId, std::shared_ptr last_raw_page, - const std::vector> &bin_raw_data); + Status AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, const int &chunk_id, + int &last_row_groupId, std::shared_ptr last_raw_page, + const std::vector> &bin_raw_data); /// \brief write blob chunk to disk - MSRStatus FlushBlobChunk(const std::shared_ptr &out, const std::vector> &blob_data, - const std::pair &blob_row); + Status FlushBlobChunk(const std::shared_ptr &out, const std::vector> &blob_data, + const std::pair &blob_row); /// \brief write raw chunk to disk - MSRStatus FlushRawChunk(const std::shared_ptr &out, - const std::vector> &rows_in_group, const int &chunk_id, - const std::vector> &bin_raw_data); + Status FlushRawChunk(const std::shared_ptr &out, const std::vector> &rows_in_group, + const int &chunk_id, const std::vector> &bin_raw_data); /// \brief break up into tasks by shard std::vector> BreakIntoShards(); /// \brief calculate raw data size row by row - MSRStatus SetRawDataSize(const std::vector> &bin_raw_data); + Status SetRawDataSize(const std::vector> &bin_raw_data); /// \brief calculate blob data size row by row - MSRStatus SetBlobDataSize(const std::vector> &blob_data); + Status SetBlobDataSize(const std::vector> &blob_data); /// \brief populate last raw page pointer - void SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page); + Status SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page); /// \brief populate last blob page pointer - void SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page); + Status SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page); /// \brief check the data by schema - MSRStatus CheckData(const std::map> &raw_data); + Status CheckData(const std::map> &raw_data); /// \brief check the data and type - MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, - std::map &err_raw_data); + Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, + std::map &err_raw_data); /// \brief Lock writer and save pages info - int LockWriter(bool parallel_writer = false); + Status LockWriter(bool parallel_writer, std::unique_ptr *fd_ptr); /// \brief Unlock writer and save pages info - MSRStatus UnlockWriter(int fd, bool parallel_writer = false); + Status UnlockWriter(int fd, bool parallel_writer = false); /// \brief Check raw data before writing - MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, - bool sign, int *schema_count, int *row_count); + Status WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, + bool sign, int *schema_count, int *row_count); /// \brief Get full path from file name - MSRStatus GetFullPathFromFileName(const std::vector &paths); + Status GetFullPathFromFileName(const std::vector &paths); /// \brief Open files - MSRStatus OpenDataFiles(bool append); + Status OpenDataFiles(bool append); /// \brief Remove lock file - MSRStatus RemoveLockFile(); + Status RemoveLockFile(); /// \brief Remove lock file - MSRStatus InitLockFile(); + Status InitLockFile(); private: const std::string kLockFileSuffix = "_Locker"; diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc index 4c6681e151..2b2ba09124 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_index_generator.cc @@ -37,70 +37,48 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe task_(0), write_success_(true) {} -MSRStatus ShardIndexGenerator::Build() { - auto ret = ShardHeader::BuildSingleHeader(file_path_); - if (ret.first != SUCCESS) { - return FAILED; - } - auto json_header = ret.second; - - auto ret2 = GetDatasetFiles(file_path_, json_header["shard_addresses"]); - if (SUCCESS != ret2.first) { - return FAILED; - } +Status ShardIndexGenerator::Build() { + std::shared_ptr header_ptr; + RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path_, &header_ptr)); + auto ds = std::make_shared>(); + RETURN_IF_NOT_OK(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds)); ShardHeader header = ShardHeader(); - auto addresses = ret2.second; - if (header.BuildDataset(addresses) == FAILED) { - return FAILED; - } + RETURN_IF_NOT_OK(header.BuildDataset(*ds)); shard_header_ = header; MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; - return SUCCESS; + return Status::OK(); } -std::pair ShardIndexGenerator::GetValueByField(const string &field, json input) { - if (field.empty()) { - MS_LOG(ERROR) << "The input field is None."; - return {FAILED, ""}; - } - - if (input.empty()) { - MS_LOG(ERROR) << "The input json is None."; - return {FAILED, ""}; - } +Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr *value) { + RETURN_UNEXPECTED_IF_NULL(value); + CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty."); + CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty."); // parameter input does not contain the field - if (input.find(field) == input.end()) { - MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input; - return {FAILED, ""}; - } + CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(), + "The field " + field + " is not found in json " + input.dump()); // schema does not contain the field auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; - if (schema.find(field) == schema.end()) { - MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; - return {FAILED, ""}; - } + CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), + "The field " + field + " is not found in schema " + schema.dump()); // field should be scalar type - if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) { - MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable"; - return {FAILED, ""}; - } + CHECK_FAIL_RETURN_UNEXPECTED( + kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(), + "The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable."); if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { auto schema_field_options = schema[field]; - if (schema_field_options.find("shape") == schema_field_options.end()) { - return {SUCCESS, input[field].dump()}; - } else { - // field with shape option - MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable"; - return {FAILED, ""}; - } + CHECK_FAIL_RETURN_UNEXPECTED( + schema_field_options.find("shape") == schema_field_options.end(), + "The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable."); + *value = std::make_shared(input[field].dump()); + } else { + // the field type is string in here + *value = std::make_shared(input[field].get()); } - - // the field type is string in here - return {SUCCESS, input[field].get()}; + return Status::OK(); } std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { @@ -150,24 +128,28 @@ int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char ** return 0; } -MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { +Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { char *z_err_msg = nullptr; int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Sql error: " << z_err_msg; + std::ostringstream oss; + oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg; + MS_LOG(DEBUG) << oss.str(); sqlite3_free(z_err_msg); - return FAILED; + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED(oss.str()); } else { if (!success_msg.empty()) { - MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg; + MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg; } sqlite3_free(z_err_msg); - return SUCCESS; + return Status::OK(); } } -std::pair ShardIndexGenerator::GenerateFieldName( - const std::pair &field) { +Status ShardIndexGenerator::GenerateFieldName(const std::pair &field, + std::shared_ptr *fn_ptr) { + RETURN_UNEXPECTED_IF_NULL(fn_ptr); // Replaces dots and dashes with underscores for SQL use std::string field_name = field.second; // white list to avoid sql injection @@ -176,95 +158,71 @@ std::pair ShardIndexGenerator::GenerateFieldName( auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); }); - if (pos != field_name.end()) { - MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name; - return {FAILED, ""}; - } - return {SUCCESS, field_name + "_" + std::to_string(field.first)}; + CHECK_FAIL_RETURN_UNEXPECTED( + pos == field_name.end(), + "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name); + *fn_ptr = std::make_shared(field_name + "_" + std::to_string(field.first)); + return Status::OK(); } -std::pair ShardIndexGenerator::CheckDatabase(const std::string &shard_address) { +Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) { auto realpath = Common::GetRealPath(shard_address); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << shard_address; - return {FAILED, nullptr}; - } - - sqlite3 *db = nullptr; + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address); std::ifstream fin(realpath.value()); if (!append_ && fin.good()) { - MS_LOG(ERROR) << "Invalid file, DB file already exist: " << shard_address; fin.close(); - return {FAILED, nullptr}; + RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address); } fin.close(); - int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); - if (rc) { - MS_LOG(ERROR) << "Invalid file, failed to open database: " << shard_address << ", error" << sqlite3_errmsg(db); - return {FAILED, nullptr}; - } else { - MS_LOG(DEBUG) << "Opened database successfully"; - return {SUCCESS, db}; + if (sqlite3_open_v2(common::SafeCStr(shard_address), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) { + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" + + std::string(sqlite3_errmsg(*db))); } + MS_LOG(DEBUG) << "Opened database successfully"; + return Status::OK(); } -MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { +Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { // create shard_name table std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; - if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { - return FAILED; - } - + RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "drop table successfully.")); sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; - if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { - return FAILED; - } - + RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "create table successfully.")); sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);"; sqlite3_stmt *stmt = nullptr; if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { if (stmt != nullptr) { (void)sqlite3_finalize(stmt); } - MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; - return FAILED; + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql); } int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME"); if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) { (void)sqlite3_finalize(stmt); - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << shard_name; - return FAILED; + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + + ", field value: " + std::string(shard_name)); } if (sqlite3_step(stmt) != SQLITE_DONE) { (void)sqlite3_finalize(stmt); - MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; - return FAILED; + RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt."); } - (void)sqlite3_finalize(stmt); - return SUCCESS; + return Status::OK(); } -std::pair ShardIndexGenerator::CreateDatabase(int shard_no) { +Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) { std::string shard_address = shard_header_.GetShardAddressByID(shard_no); - if (shard_address.empty()) { - MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; - return {FAILED, nullptr}; - } - - string shard_name = GetFileName(shard_address).second; + CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + shard_no); + std::shared_ptr fn_ptr; + RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr)); shard_address += ".db"; - auto ret1 = CheckDatabase(shard_address); - if (ret1.first != SUCCESS) { - return {FAILED, nullptr}; - } - sqlite3 *db = ret1.second; + RETURN_IF_NOT_OK(CheckDatabase(shard_address, db)); std::string sql = "DROP TABLE IF EXISTS INDEXES;"; - if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { - return {FAILED, nullptr}; - } + RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "drop table successfully.")); sql = "CREATE TABLE INDEXES(" " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" @@ -273,95 +231,79 @@ std::pair ShardIndexGenerator::CreateDatabase(int shard_no ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; int field_no = 0; + std::shared_ptr field_ptr; for (const auto &field : fields_) { uint64_t schema_id = field.first; - auto result = shard_header_.GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - return {FAILED, nullptr}; - } - json json_schema = (result.first->GetSchema())["schema"]; + std::shared_ptr schema_ptr; + RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(schema_id, &schema_ptr)); + json json_schema = (schema_ptr->GetSchema())["schema"]; std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, nullptr}; - } - sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type; + RETURN_IF_NOT_OK(GenerateFieldName(field, &field_ptr)); + sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type; } sql += ", PRIMARY KEY(ROW_ID"; for (uint64_t i = 0; i < fields_.size(); ++i) { sql += ",INC_" + std::to_string(i); } sql += "));"; - if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { - return {FAILED, nullptr}; - } - - if (CreateShardNameTable(db, shard_name) != SUCCESS) { - return {FAILED, nullptr}; - } - return {SUCCESS, db}; + RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "create table successfully.")); + RETURN_IF_NOT_OK(CreateShardNameTable(*db, *fn_ptr)); + return Status::OK(); } -std::pair> ShardIndexGenerator::GetSchemaDetails(const std::vector &schema_lens, - std::fstream &in) { - std::vector schema_details; +Status ShardIndexGenerator::GetSchemaDetails(const std::vector &schema_lens, std::fstream &in, + std::shared_ptr> *detail_ptr) { + RETURN_UNEXPECTED_IF_NULL(detail_ptr); if (schema_count_ <= kMaxSchemaCount) { for (int sc = 0; sc < schema_count_; ++sc) { std::vector schema_detail(schema_lens[sc]); - auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; in.close(); - return {FAILED, {}}; + RETURN_STATUS_UNEXPECTED("Failed to read file."); } - - schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()))); + auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end())); + (*detail_ptr)->emplace_back(j); } } - - return {SUCCESS, schema_details}; + return Status::OK(); } -std::pair ShardIndexGenerator::GenerateRawSQL( - const std::vector> &fields) { +Status ShardIndexGenerator::GenerateRawSQL(const std::vector> &fields, + std::shared_ptr *sql_ptr) { std::string sql = "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; int field_no = 0; for (const auto &field : fields) { - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, ""}; - } - sql += ",INC_" + std::to_string(field_no++) + "," + ret.second; + std::shared_ptr fn_ptr; + RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); + sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr; } sql += ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; field_no = 0; for (const auto &field : fields) { - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, ""}; - } - sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second; + std::shared_ptr fn_ptr; + RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); + sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr; } sql += " )"; - return {SUCCESS, sql}; + + *sql_ptr = std::make_shared(sql); + return Status::OK(); } -MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( - sqlite3 *db, const std::string &sql, - const std::vector>> &data) { +Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) { sqlite3_stmt *stmt = nullptr; if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { if (stmt != nullptr) { (void)sqlite3_finalize(stmt); } - MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; - return FAILED; + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql); } for (auto &row : data) { for (auto &field : row) { @@ -373,45 +315,47 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( if (field_type == "INTEGER") { if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { (void)sqlite3_finalize(stmt); - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index - << ", field value: " << std::stoll(field_value); - return FAILED; + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + + ", field value: " + std::string(field_value)); } } else if (field_type == "NUMERIC") { if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { (void)sqlite3_finalize(stmt); - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index - << ", field value: " << std::stold(field_value); - return FAILED; + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + + ", field value: " + std::string(field_value)); } } else if (field_type == "NULL") { if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { (void)sqlite3_finalize(stmt); - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL"; - return FAILED; + + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + + ", field value: NULL"); } } else { if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { (void)sqlite3_finalize(stmt); - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value; - return FAILED; + sqlite3_close(db); + RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + + ", field value: " + std::string(field_value)); } } } if (sqlite3_step(stmt) != SQLITE_DONE) { (void)sqlite3_finalize(stmt); - MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; - return FAILED; + RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt."); } (void)sqlite3_reset(stmt); } (void)sqlite3_finalize(stmt); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector> &row_data, - const std::shared_ptr cur_blob_page, - uint64_t &cur_blob_page_offset, std::fstream &in) { +Status ShardIndexGenerator::AddBlobPageInfo(std::vector> &row_data, + const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, + std::fstream &in) { row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); // blob data start @@ -419,89 +363,71 @@ MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vectorGetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { - MS_LOG(ERROR) << "File seekg failed"; in.close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } - uint64_t image_size = 0; - auto &io_read = in.read(reinterpret_cast(&image_size), kInt64Len); if (!io_read.good() || io_read.fail() || io_read.bad()) { MS_LOG(ERROR) << "File read failed"; in.close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to read file."); } cur_blob_page_offset += (kInt64Len + image_size); row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); - return SUCCESS; + return Status::OK(); } -void ShardIndexGenerator::AddIndexFieldByRawData( +Status ShardIndexGenerator::AddIndexFieldByRawData( const std::vector &schema_detail, std::vector> &row_data) { - auto result = GenerateIndexFields(schema_detail); - if (result.first == SUCCESS) { - int index = 0; - for (const auto &field : result.second) { - // assume simple field: string , number etc. - row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); - row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); - } - } + auto index_fields_ptr = std::make_shared(); + RETURN_IF_NOT_OK(GenerateIndexFields(schema_detail, &index_fields_ptr)); + int index = 0; + for (const auto &field : *index_fields_ptr) { + // assume simple field: string , number etc. + row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); + row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); + } + return Status::OK(); } -ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, - int raw_page_id, std::fstream &in) { - std::vector>> full_data; - +Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map &blob_id_to_page_id, int raw_page_id, + std::fstream &in, std::shared_ptr *row_data_ptr) { + RETURN_UNEXPECTED_IF_NULL(row_data_ptr); // current raw data page - auto ret1 = shard_header_.GetPage(shard_no, raw_page_id); - if (ret1.second != SUCCESS) { - MS_LOG(ERROR) << "Get page failed"; - return {FAILED, {}}; - } - std::shared_ptr cur_raw_page = ret1.first; - + std::shared_ptr page_ptr; + RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr)); // related blob page - vector> row_group_list = cur_raw_page->GetRowGroupIds(); + vector> row_group_list = page_ptr->GetRowGroupIds(); // pair: row_group id, offset in raw data page for (pair blob_ids : row_group_list) { // get blob data page according to row_group id auto iter = blob_id_to_page_id.find(blob_ids.first); - if (iter == blob_id_to_page_id.end()) { - MS_LOG(ERROR) << "Convert blob id failed"; - return {FAILED, {}}; - } - auto ret2 = shard_header_.GetPage(shard_no, iter->second); - if (ret2.second != SUCCESS) { - MS_LOG(ERROR) << "Get page failed"; - return {FAILED, {}}; - } - std::shared_ptr cur_blob_page = ret2.first; - + CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id."); + std::shared_ptr blob_page_ptr; + RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr)); // offset in current raw data page auto cur_raw_page_offset = static_cast(blob_ids.second); uint64_t cur_blob_page_offset = 0; - for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { + for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) { std::vector> row_data; row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); - row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); - row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); + row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID())); + row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID())); // raw data start row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); // calculate raw data end auto &io_seekg = - in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); + in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; - return {FAILED, {}}; + in.close(); + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } - std::vector schema_lens; if (schema_count_ <= kMaxSchemaCount) { for (int sc = 0; sc < schema_count_; sc++) { @@ -509,8 +435,8 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map(&schema_size), kInt64Len); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; - return {FAILED, {}}; + in.close(); + RETURN_STATUS_UNEXPECTED("Failed to read file."); } cur_raw_page_offset += (kInt64Len + schema_size); @@ -520,122 +446,79 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map>(); + RETURN_IF_NOT_OK(GetSchemaDetails(schema_lens, in, &detail_ptr)); // start blob page info - if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) { - return {FAILED, {}}; - } + RETURN_IF_NOT_OK(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in)); // start index field - AddIndexFieldByRawData(st_schema_detail.second, row_data); - full_data.push_back(std::move(row_data)); + AddIndexFieldByRawData(*detail_ptr, row_data); + (*row_data_ptr)->push_back(std::move(row_data)); } } - return {SUCCESS, full_data}; + return Status::OK(); } -INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector &schema_detail) { - std::vector> fields; +Status ShardIndexGenerator::GenerateIndexFields(const std::vector &schema_detail, + std::shared_ptr *index_fields_ptr) { + RETURN_UNEXPECTED_IF_NULL(index_fields_ptr); // index fields std::vector> index_fields = shard_header_.GetFields(); for (const auto &field : index_fields) { - if (field.first >= schema_detail.size()) { - return {FAILED, {}}; - } - auto field_value = GetValueByField(field.second, schema_detail[field.first]); - if (field_value.first != SUCCESS) { - MS_LOG(ERROR) << "Get value from json by field name failed"; - return {FAILED, {}}; - } - - auto result = shard_header_.GetSchemaByID(field.first); - if (result.second != SUCCESS) { - return {FAILED, {}}; - } - - std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); - auto ret = GenerateFieldName(field); - if (ret.first != SUCCESS) { - return {FAILED, {}}; - } - - fields.emplace_back(ret.second, field_type, field_value.second); - } - return {SUCCESS, std::move(fields)}; + CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range."); + std::shared_ptr field_val_ptr; + RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr)); + std::shared_ptr schema_ptr; + RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(field.first, &schema_ptr)); + std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"])); + std::shared_ptr fn_ptr; + RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); + (*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr); + } + return Status::OK(); } -MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair &db, - const std::vector &raw_page_ids, - const std::map &blob_id_to_page_id) { +Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector &raw_page_ids, + const std::map &blob_id_to_page_id) { // Add index data to database std::string shard_address = shard_header_.GetShardAddressByID(shard_no); - if (shard_address.empty()) { - MS_LOG(ERROR) << "Invalid data, shard address is null"; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty."); auto realpath = Common::GetRealPath(shard_address); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << shard_address; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address); std::fstream in; in.open(realpath.value(), std::ios::in | std::ios::binary); if (!in.good()) { - MS_LOG(ERROR) << "Invalid file, failed to open file: " << shard_address; in.close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address); } - (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); + (void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); for (int raw_page_id : raw_page_ids) { - auto sql = GenerateRawSQL(fields_); - if (sql.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw SQL failed"; - in.close(); - sqlite3_close(db.second); - return FAILED; - } - auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); - if (data.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw data failed"; - in.close(); - sqlite3_close(db.second); - return FAILED; - } - if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { - MS_LOG(ERROR) << "Execute SQL failed"; - in.close(); - sqlite3_close(db.second); - return FAILED; - } - MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; - } - (void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr); + std::shared_ptr sql_ptr; + RELEASE_AND_RETURN_IF_NOT_OK(GenerateRawSQL(fields_, &sql_ptr), db, in); + auto row_data_ptr = std::make_shared(); + RELEASE_AND_RETURN_IF_NOT_OK(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db, in); + RELEASE_AND_RETURN_IF_NOT_OK(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in); + MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db."; + } + (void)sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); in.close(); // Close database - if (sqlite3_close(db.second) != SQLITE_OK) { - MS_LOG(ERROR) << "Close database failed"; - return FAILED; - } - db.second = nullptr; - return SUCCESS; + sqlite3_close(db); + db = nullptr; + return Status::OK(); } -MSRStatus ShardIndexGenerator::WriteToDatabase() { +Status ShardIndexGenerator::WriteToDatabase() { fields_ = shard_header_.GetFields(); page_size_ = shard_header_.GetPageSize(); header_size_ = shard_header_.GetHeaderSize(); schema_count_ = shard_header_.GetSchemaCount(); - if (shard_header_.GetShardCount() > kMaxShardCount) { - MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount, + "num shards: " + std::to_string(shard_header_.GetShardCount()) + + " exceeds max count:" + std::to_string(kMaxSchemaCount)); + task_ = 0; // set two atomic vars to initial value write_success_ = true; @@ -653,40 +536,41 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { for (size_t t = 0; t < threads.capacity(); t++) { threads[t].join(); } - return write_success_ ? SUCCESS : FAILED; + CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db."); + return Status::OK(); } void ShardIndexGenerator::DatabaseWriter() { int shard_no = task_++; while (shard_no < shard_header_.GetShardCount()) { - auto db = CreateDatabase(shard_no); - if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { + sqlite3 *db = nullptr; + if (CreateDatabase(shard_no, &db).IsError()) { + MS_LOG(ERROR) << "Failed to create Generate database."; write_success_ = false; return; } - MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; - // Pre-processing page information auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; std::map blob_id_to_page_id; std::vector raw_page_ids; for (uint64_t i = 0; i < total_pages; ++i) { - auto ret = shard_header_.GetPage(shard_no, i); - if (ret.second != SUCCESS) { + std::shared_ptr page_ptr; + if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) { + MS_LOG(ERROR) << "Failed to get page."; write_success_ = false; return; } - std::shared_ptr cur_page = ret.first; - if (cur_page->GetPageType() == "RAW_DATA") { + if (page_ptr->GetPageType() == "RAW_DATA") { raw_page_ids.push_back(i); - } else if (cur_page->GetPageType() == "BLOB_DATA") { - blob_id_to_page_id[cur_page->GetPageTypeID()] = i; + } else if (page_ptr->GetPageType() == "BLOB_DATA") { + blob_id_to_page_id[page_ptr->GetPageTypeID()] = i; } } - if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { + if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) { + MS_LOG(ERROR) << "Failed to execute transaction."; write_success_ = false; return; } @@ -694,21 +578,12 @@ void ShardIndexGenerator::DatabaseWriter() { shard_no = task_++; } } -MSRStatus ShardIndexGenerator::Finalize(const std::vector file_names) { - if (file_names.empty()) { - MS_LOG(ERROR) << "Mindrecord files is empty."; - return FAILED; - } +Status ShardIndexGenerator::Finalize(const std::vector file_names) { + CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty."); ShardIndexGenerator sg{file_names[0]}; - if (SUCCESS != sg.Build()) { - MS_LOG(ERROR) << "Failed to build index generator."; - return FAILED; - } - if (SUCCESS != sg.WriteToDatabase()) { - MS_LOG(ERROR) << "Failed to write to database."; - return FAILED; - } - return SUCCESS; + RETURN_IF_NOT_OK(sg.Build()); + RETURN_IF_NOT_OK(sg.WriteToDatabase()); + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index ec5bd0436d..a9fd12d7ac 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -53,64 +53,45 @@ ShardReader::ShardReader() lazy_load_(false), shard_sample_count_() {} -std::pair> ShardReader::GetMeta(const std::string &file_path, - std::shared_ptr meta_data_ptr) { - if (!IsLegalFile(file_path)) { - return {FAILED, {}}; - } - auto ret = ShardHeader::BuildSingleHeader(file_path); - if (ret.first != SUCCESS) { - return {FAILED, {}}; - } - auto header = ret.second; - *meta_data_ptr = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, - {"version", header["version"]}, {"index_fields", header["index_fields"]}, - {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; - return {SUCCESS, header["shard_addresses"]}; +Status ShardReader::GetMeta(const std::string &file_path, std::shared_ptr meta_data_ptr, + std::shared_ptr> *addresses_ptr) { + RETURN_UNEXPECTED_IF_NULL(addresses_ptr); + CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(file_path), "Invalid file path: " + file_path); + std::shared_ptr header_ptr; + RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path, &header_ptr)); + + *meta_data_ptr = {{"header_size", (*header_ptr)["header_size"]}, {"page_size", (*header_ptr)["page_size"]}, + {"version", (*header_ptr)["version"]}, {"index_fields", (*header_ptr)["index_fields"]}, + {"schema", (*header_ptr)["schema"]}, {"blob_fields", (*header_ptr)["blob_fields"]}}; + *addresses_ptr = std::make_shared>((*header_ptr)["shard_addresses"]); + return Status::OK(); } -MSRStatus ShardReader::Init(const std::vector &file_paths, bool load_dataset) { +Status ShardReader::Init(const std::vector &file_paths, bool load_dataset) { std::string file_path = file_paths[0]; auto first_meta_data_ptr = std::make_shared(); - auto ret = GetMeta(file_path, first_meta_data_ptr); - if (ret.first != SUCCESS) { - return FAILED; - } + std::shared_ptr> addresses_ptr; + RETURN_IF_NOT_OK(GetMeta(file_path, first_meta_data_ptr, &addresses_ptr)); if (file_paths.size() == 1 && load_dataset == true) { - auto ret2 = GetDatasetFiles(file_path, ret.second); - if (SUCCESS != ret2.first) { - return FAILED; - } - file_paths_ = ret2.second; + auto ds = std::make_shared>(); + RETURN_IF_NOT_OK(GetDatasetFiles(file_path, *addresses_ptr, &ds)); + file_paths_ = *ds; } else if (file_paths.size() >= 1 && load_dataset == false) { file_paths_ = file_paths; } else { - MS_LOG(ERROR) << "Error in parameter file_path or load_dataset."; - return FAILED; + RETURN_STATUS_UNEXPECTED("Error in parameter file_path or load_dataset."); } for (const auto &file : file_paths_) { auto meta_data_ptr = std::make_shared(); - auto ret1 = GetMeta(file, meta_data_ptr); - if (ret1.first != SUCCESS) { - return FAILED; - } - if (*meta_data_ptr != *first_meta_data_ptr) { - MS_LOG(ERROR) << "Mindrecord files meta information is different."; - return FAILED; - } + RETURN_IF_NOT_OK(GetMeta(file, meta_data_ptr, &addresses_ptr)); + CHECK_FAIL_RETURN_UNEXPECTED(*meta_data_ptr == *first_meta_data_ptr, + "Mindrecord files meta information is different."); sqlite3 *db = nullptr; - auto ret3 = VerifyDataset(&db, file); - if (ret3 != SUCCESS) { - sqlite3_close(db); - return FAILED; - } - + RETURN_IF_NOT_OK(VerifyDataset(&db, file)); database_paths_.push_back(db); } ShardHeader sh = ShardHeader(); - if (sh.BuildDataset(file_paths_, load_dataset) == FAILED) { - return FAILED; - } + RETURN_IF_NOT_OK(sh.BuildDataset(file_paths_, load_dataset)); shard_header_ = std::make_shared(sh); header_size_ = shard_header_->GetHeaderSize(); page_size_ = shard_header_->GetPageSize(); @@ -147,43 +128,39 @@ MSRStatus ShardReader::Init(const std::vector &file_paths, bool loa MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; - return SUCCESS; + return Status::OK(); } -MSRStatus ShardReader::VerifyDataset(sqlite3 **db, const string &file) { +Status ShardReader::VerifyDataset(sqlite3 **db, const string &file) { // sqlite3_open create a database if not found, use sqlite3_open_v2 instead of it - auto rc = sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Invalid file, failed to open database: " << file + ".db, error: " << sqlite3_errmsg(*db); - sqlite3_close(*db); - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED( + sqlite3_open_v2(common::SafeCStr(file + ".db"), db, SQLITE_OPEN_READONLY, nullptr) == SQLITE_OK, + "Invalid database file: " + file + ".db, error: " + sqlite3_errmsg(*db)); MS_LOG(DEBUG) << "Opened database successfully"; string sql = "SELECT NAME from SHARD_NAME;"; std::vector> name; char *errmsg = nullptr; - rc = sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + if (sqlite3_exec(*db, common::SafeCStr(sql), SelectCallback, &name, &errmsg) != SQLITE_OK) { + std::ostringstream oss; + oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg; sqlite3_free(errmsg); sqlite3_close(*db); - return FAILED; + RETURN_STATUS_UNEXPECTED(oss.str()); } else { MS_LOG(DEBUG) << "Get " << static_cast(name.size()) << " records from index."; - string shardName = GetFileName(file).second; - if (name.empty() || name[0][0] != shardName) { - MS_LOG(ERROR) << "Invalid file, DB file can not match file: " << file; + std::shared_ptr fn_ptr; + RETURN_IF_NOT_OK(GetFileName(file, &fn_ptr)); + if (name.empty() || name[0][0] != *fn_ptr) { sqlite3_free(errmsg); sqlite3_close(*db); - return FAILED; + RETURN_STATUS_UNEXPECTED("Invalid file, DB file can not match file: " + file); } } - sqlite3_free(errmsg); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardReader::CheckColumnList(const std::vector &selected_columns) { +Status ShardReader::CheckColumnList(const std::vector &selected_columns) { vector inSchema(selected_columns.size(), 0); for (auto &p : GetShardHeader()->GetSchemas()) { auto schema = p->GetSchema()["schema"]; @@ -193,59 +170,40 @@ MSRStatus ShardReader::CheckColumnList(const std::vector &selected_ } } } - if (std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; })) { - return FAILED; - } - - return SUCCESS; + CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(std::begin(inSchema), std::end(inSchema), [](int x) { return x == 0; }), + "Column not found in schema."); + return Status::OK(); } -MSRStatus ShardReader::Open() { +Status ShardReader::Open() { file_streams_.clear(); - for (const auto &file : file_paths_) { auto realpath = Common::GetRealPath(file); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << file; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file); std::shared_ptr fs = std::make_shared(); fs->open(realpath.value(), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "Invalid file, failed to open file: " << file; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file: " + file); MS_LOG(INFO) << "Open shard file successfully."; file_streams_.push_back(fs); } - - return SUCCESS; + return Status::OK(); } -MSRStatus ShardReader::Open(int n_consumer) { +Status ShardReader::Open(int n_consumer) { file_streams_random_ = std::vector>>(n_consumer, std::vector>()); for (const auto &file : file_paths_) { for (int j = 0; j < n_consumer; ++j) { auto realpath = Common::GetRealPath(file); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << file; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file); std::shared_ptr fs = std::make_shared(); fs->open(realpath.value(), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "Invalid file, failed to open file: " << file; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Failed to open file: " + file); file_streams_random_[j].push_back(fs); } MS_LOG(INFO) << "Open shard file successfully."; } - - return SUCCESS; + return Status::OK(); } void ShardReader::FileStreamsOperator() { @@ -314,16 +272,18 @@ std::vector> ShardReader::ReadRowGroupSummar continue; } for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { - const auto &page_t = shard_header_->GetPage(shard_id, page_id); - const auto &page = page_t.first; - if (page->GetPageType() != kPageTypeBlob) continue; - uint64_t start_row_id = page->GetStartRowID(); - if (start_row_id > page->GetEndRowID()) { + std::shared_ptr page_ptr; + (void)shard_header_->GetPage(shard_id, page_id, &page_ptr); + if (page_ptr->GetPageType() != kPageTypeBlob) { + continue; + } + uint64_t start_row_id = page_ptr->GetStartRowID(); + if (start_row_id > page_ptr->GetEndRowID()) { return std::vector>(); } - uint64_t number_of_rows = page->GetEndRowID() - start_row_id; + uint64_t number_of_rows = page_ptr->GetEndRowID() - start_row_id; total_count += number_of_rows; - row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); + row_group_summary.emplace_back(shard_id, page_ptr->GetPageTypeID(), start_row_id, number_of_rows); } shard_sample_count_.push_back(total_count); } @@ -332,16 +292,11 @@ std::vector> ShardReader::ReadRowGroupSummar return row_group_summary; } -MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) { - *total_blob_size = total_blob_size_; - return SUCCESS; -} - -MSRStatus ShardReader::ConvertLabelToJson(const std::vector> &labels, - std::shared_ptr fs, - std::shared_ptr>>> offset_ptr, - int shard_id, const std::vector &columns, - std::shared_ptr>> col_val_ptr) { +Status ShardReader::ConvertLabelToJson(const std::vector> &labels, + std::shared_ptr fs, + std::shared_ptr>>> offset_ptr, + int shard_id, const std::vector &columns, + std::shared_ptr>> col_val_ptr) { for (int i = 0; i < static_cast(labels.size()); ++i) { try { uint64_t group_id = std::stoull(labels[i][0]); @@ -357,16 +312,13 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector(len); auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; fs->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } - auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; fs->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to read file."); } json label_json = json::from_msgpack(label_raw); json tmp; @@ -402,78 +354,71 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vectorclose(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Out of range: " + std::string(e.what())); } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "Invalid argument: " << e.what(); fs->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Invalid argument: " + std::string(e.what())); } catch (...) { - MS_LOG(ERROR) << "Exception was caught while convert label to json."; fs->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Exception was caught while convert label to json."); } } + fs->close(); - return SUCCESS; -} // namespace mindrecord + return Status::OK(); +} -MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, - std::shared_ptr>>> offset_ptr, - std::shared_ptr>> col_val_ptr) { +Status ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, + std::shared_ptr>>> offset_ptr, + std::shared_ptr>> col_val_ptr) { auto db = database_paths_[shard_id]; std::vector> labels; char *errmsg = nullptr; int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + std::ostringstream oss; + oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg; sqlite3_free(errmsg); sqlite3_close(db); db = nullptr; - return FAILED; + RETURN_STATUS_UNEXPECTED(oss.str()); } MS_LOG(INFO) << "Get " << static_cast(labels.size()) << " records from shard " << shard_id << " index."; std::string file_name = file_paths_[shard_id]; - auto realpath = Common::GetRealPath(file_name); if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << file_name; sqlite3_free(errmsg); sqlite3_close(db); - return FAILED; + RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + file_name); } std::shared_ptr fs = std::make_shared(); if (!all_in_index_) { fs->open(realpath.value(), std::ios::in | std::ios::binary); if (!fs->good()) { - MS_LOG(ERROR) << "Invalid file, failed to open file: " << file_name; sqlite3_free(errmsg); sqlite3_close(db); - return FAILED; + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file_name); } } sqlite3_free(errmsg); return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr); } -MSRStatus ShardReader::GetAllClasses(const std::string &category_field, - std::shared_ptr> category_ptr) { +Status ShardReader::GetAllClasses(const std::string &category_field, + std::shared_ptr> category_ptr) { std::map index_columns; for (auto &field : GetShardHeader()->GetFields()) { index_columns[field.second] = field.first; } - if (index_columns.find(category_field) == index_columns.end()) { - MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; - return FAILED; - } - auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); - if (SUCCESS != ret.first) { - return FAILED; - } - std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + CHECK_FAIL_RETURN_UNEXPECTED(index_columns.find(category_field) != index_columns.end(), + "Index field " + category_field + " does not exist."); + std::shared_ptr fn_ptr; + RETURN_IF_NOT_OK( + ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field), &fn_ptr)); + std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES"; std::vector threads = std::vector(shard_count_); for (int x = 0; x < shard_count_; x++) { threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr); @@ -482,7 +427,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, for (int x = 0; x < shard_count_; x++) { threads[x].join(); } - return SUCCESS; + return Status::OK(); } void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, @@ -508,7 +453,9 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sqlite3_free(errmsg); } -ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector &columns) { +Status ShardReader::ReadAllRowGroup(const std::vector &columns, + std::shared_ptr *row_group_ptr) { + RETURN_UNEXPECTED_IF_NULL(row_group_ptr); std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; auto offset_ptr = std::make_shared>>>( shard_count_, std::vector>{}); @@ -517,11 +464,10 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector &columns) if (all_in_index_) { for (unsigned int i = 0; i < columns.size(); ++i) { fields += ','; - auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); - if (ret.first != SUCCESS) { - return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr)); - } - fields += ret.second; + std::shared_ptr fn_ptr; + RETURN_IF_NOT_OK( + ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr)); + fields += *fn_ptr; } } else { // fetch raw data from Raw page while some field is not index. fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END "; @@ -537,11 +483,14 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector &columns) for (int x = 0; x < shard_count_; x++) { thread_read_db[x].join(); } - return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr)); + *row_group_ptr = std::make_shared(std::move(*offset_ptr), std::move(*col_val_ptr)); + return Status::OK(); } -ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector &columns, - const uint32_t &shard_id, const uint32_t &sample_id) { +Status ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector &columns, const uint32_t &shard_id, + const uint32_t &sample_id, + std::shared_ptr *row_group_ptr) { + RETURN_UNEXPECTED_IF_NULL(row_group_ptr); std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; auto offset_ptr = std::make_shared>>>( shard_count_, std::vector>{}); @@ -549,11 +498,10 @@ ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector fn_ptr; + RETURN_IF_NOT_OK( + ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i]), &fn_ptr)); + fields += *fn_ptr; } } else { // fetch raw data from Raw page while some field is not index. fields += ", PAGE_ID_RAW, PAGE_OFFSET_RAW, PAGE_OFFSET_RAW_END "; @@ -561,60 +509,48 @@ ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector(std::move(*offset_ptr), std::move(*col_val_ptr)); + return Status::OK(); } -ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - const std::shared_ptr &page = ret.second; +Status ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector &columns, + std::shared_ptr *row_group_brief_ptr) { + RETURN_UNEXPECTED_IF_NULL(row_group_brief_ptr); + std::shared_ptr page_ptr; + RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr)); std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->GetPageSize(); - uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; - std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id); - - auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); - if (status_labels.first != SUCCESS) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), - std::move(status_labels.second)); + uint64_t page_length = page_ptr->GetPageSize(); + uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id); + auto labels_ptr = std::make_shared>(); + RETURN_IF_NOT_OK(GetLabels(page_ptr->GetPageID(), shard_id, columns, {"", ""}, &labels_ptr)); + *row_group_brief_ptr = std::make_shared(file_name, page_length, page_offset, std::move(image_offset), + std::move(*labels_ptr)); + return Status::OK(); } -ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, - const std::pair &criteria, - const std::vector &columns) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } +Status ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, + const std::pair &criteria, + const std::vector &columns, + std::shared_ptr *row_group_brief_ptr) { + RETURN_UNEXPECTED_IF_NULL(row_group_brief_ptr); + std::shared_ptr page_ptr; + RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr)); vector criteria_list{criteria.first}; - if (CheckColumnList(criteria_list) == FAILED) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - const std::shared_ptr &page = ret.second; + RETURN_IF_NOT_OK(CheckColumnList(criteria_list)); std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->GetPageSize(); - uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; - std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); + uint64_t page_length = page_ptr->GetPageSize(); + uint64_t page_offset = page_size_ * page_ptr->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page_ptr->GetPageID(), shard_id, criteria); if (image_offset.empty()) { - return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::vector>(), - std::vector()); + *row_group_brief_ptr = std::make_shared(); } - auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); - if (status_labels.first != SUCCESS) { - return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); - } - - return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::move(image_offset), - std::move(status_labels.second)); + auto labels_ptr = std::make_shared>(); + RETURN_IF_NOT_OK(GetLabels(page_ptr->GetPageID(), shard_id, columns, criteria, &labels_ptr)); + *row_group_brief_ptr = std::make_shared(file_name, page_length, page_offset, std::move(image_offset), + std::move(*labels_ptr)); + return Status::OK(); } int ShardReader::SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names) { @@ -671,8 +607,9 @@ std::vector> ShardReader::GetImageOffset(int page_id, int return res; } -std::pair> ShardReader::GetPagesByCategory( - int shard_id, const std::pair &criteria) { +Status ShardReader::GetPagesByCategory(int shard_id, const std::pair &criteria, + std::shared_ptr> *pages_ptr) { + RETURN_UNEXPECTED_IF_NULL(pages_ptr); auto db = database_paths_[shard_id]; std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 "; @@ -692,20 +629,19 @@ std::pair> ShardReader::GetPagesByCategory( char *errmsg = nullptr; int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg); if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + string ss(errmsg); sqlite3_free(errmsg); sqlite3_close(db); db = nullptr; - return std::make_pair(FAILED, std::vector()); + RETURN_STATUS_UNEXPECTED("Error in select statement, sql: " + sql + ", error: " + ss); } else { MS_LOG(DEBUG) << "Get " << page_ids.size() << "pages from index."; } - std::vector res; for (int i = 0; i < static_cast(page_ids.size()); ++i) { - res.emplace_back(std::stoull(page_ids[i][0])); + (*pages_ptr)->emplace_back(std::stoull(page_ids[i][0])); } sqlite3_free(errmsg); - return std::make_pair(SUCCESS, res); + return Status::OK(); } std::pair> ShardReader::GetBlobFields() { @@ -736,17 +672,16 @@ void ShardReader::CheckIfColumnInIndex(const std::vector &columns) } } -MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, - std::shared_ptr>> labels_ptr) { +Status ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, + std::shared_ptr>> labels_ptr) { sqlite3_stmt *stmt = nullptr; if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; - return FAILED; + RETURN_STATUS_UNEXPECTED(std::string("SQL error: could not prepare statement, sql: ") + sql); } int index = sqlite3_bind_parameter_index(stmt, ":criteria"); if (sqlite3_bind_text(stmt, index, common::SafeCStr(criteria), -1, SQLITE_STATIC) != SQLITE_OK) { - MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << criteria; - return FAILED; + RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + + ", field value: " + criteria); } int rc = sqlite3_step(stmt); while (rc != SQLITE_DONE) { @@ -759,30 +694,23 @@ MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const s rc = sqlite3_step(stmt); } (void)sqlite3_finalize(stmt); - return SUCCESS; + return Status::OK(); } -std::pair> ShardReader::GetLabelsFromBinaryFile( - int shard_id, const std::vector &columns, const std::vector> &label_offsets) { +Status ShardReader::GetLabelsFromBinaryFile(int shard_id, const std::vector &columns, + const std::vector> &label_offsets, + std::shared_ptr> *labels_ptr) { + RETURN_UNEXPECTED_IF_NULL(labels_ptr); std::string file_name = file_paths_[shard_id]; - auto realpath = Common::GetRealPath(file_name); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << file_name; - return {FAILED, {}}; - } + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file_name); - std::vector res; std::shared_ptr fs = std::make_shared(); fs->open(realpath.value(), std::ios::in | std::ios::binary); - if (!fs->good()) { - MS_LOG(ERROR) << "Invalid file, failed to open file: " << file_name; - return {FAILED, {}}; - } - + CHECK_FAIL_RETURN_UNEXPECTED(fs->good(), "Invalid file, failed to open file: " + file_name); // init the return for (unsigned int i = 0; i < label_offsets.size(); ++i) { - res.emplace_back(json{}); + (*labels_ptr)->emplace_back(json{}); } for (unsigned int i = 0; i < label_offsets.size(); ++i) { @@ -794,16 +722,14 @@ std::pair> ShardReader::GetLabelsFromBinaryFile( auto label_raw = std::vector(len); auto &io_seekg = fs->seekg(page_size_ * raw_page_id + header_size_ + label_start, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; fs->close(); - return {FAILED, {}}; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } auto &io_read = fs->read(reinterpret_cast(&label_raw[0]), len); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; fs->close(); - return {FAILED, {}}; + RETURN_STATUS_UNEXPECTED("Failed to read file."); } json label_json = json::from_msgpack(label_raw); @@ -813,14 +739,14 @@ std::pair> ShardReader::GetLabelsFromBinaryFile( tmp[col] = label_json[col]; } } - res[i] = tmp; + (*(*labels_ptr))[i] = tmp; } - return {SUCCESS, res}; + return Status::OK(); } - -std::pair> ShardReader::GetLabelsFromPage( - int page_id, int shard_id, const std::vector &columns, - const std::pair &criteria) { +Status ShardReader::GetLabelsFromPage(int page_id, int shard_id, const std::vector &columns, + const std::pair &criteria, + std::shared_ptr> *labels_ptr) { + RETURN_UNEXPECTED_IF_NULL(labels_ptr); // get page info from sqlite auto db = database_paths_[shard_id]; std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " + @@ -828,30 +754,30 @@ std::pair> ShardReader::GetLabelsFromPage( auto label_offset_ptr = std::make_shared>>(); if (!criteria.first.empty()) { sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria"; - if (QueryWithCriteria(db, sql, criteria.second, label_offset_ptr) == FAILED) { - return {FAILED, {}}; - } + RETURN_IF_NOT_OK(QueryWithCriteria(db, sql, criteria.second, label_offset_ptr)); } else { sql += ";"; char *errmsg = nullptr; int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg); if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + std::ostringstream oss; + oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg; sqlite3_free(errmsg); sqlite3_close(db); db = nullptr; - return {FAILED, {}}; + RETURN_STATUS_UNEXPECTED(oss.str()); } MS_LOG(DEBUG) << "Get " << label_offset_ptr->size() << " records from index."; sqlite3_free(errmsg); } // get labels from binary file - return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr); + return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr, labels_ptr); } -std::pair> ShardReader::GetLabels(int page_id, int shard_id, - const std::vector &columns, - const std::pair &criteria) { +Status ShardReader::GetLabels(int page_id, int shard_id, const std::vector &columns, + const std::pair &criteria, + std::shared_ptr> *labels_ptr) { + RETURN_UNEXPECTED_IF_NULL(labels_ptr); if (all_in_index_) { auto db = database_paths_[shard_id]; std::string fields; @@ -860,34 +786,34 @@ std::pair> ShardReader::GetLabels(int page_id, int uint64_t schema_id = column_schema_id_[columns[i]]; fields += columns[i] + "_" + std::to_string(schema_id); } - if (fields.empty()) fields = "*"; - auto labels_ptr = std::make_shared>>(); + if (fields.empty()) { + fields = "*"; + } + auto labels = std::make_shared>>(); std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); if (!criteria.first.empty()) { sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria"; - if (QueryWithCriteria(db, sql, criteria.second, labels_ptr) == FAILED) { - return {FAILED, {}}; - } + RETURN_IF_NOT_OK(QueryWithCriteria(db, sql, criteria.second, labels)); } else { sql += ";"; char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels_ptr.get(), &errmsg); + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels.get(), &errmsg); if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + std::ostringstream oss; + oss << "Error in execute sql: [ " << sql + " ], error: " << errmsg; sqlite3_free(errmsg); sqlite3_close(db); db = nullptr; - return {FAILED, {}}; + RETURN_STATUS_UNEXPECTED(oss.str()); } else { - MS_LOG(DEBUG) << "Get " << static_cast(labels_ptr->size()) << " records from index."; + MS_LOG(DEBUG) << "Get " << static_cast(labels->size()) << " records from index."; } sqlite3_free(errmsg); } - std::vector ret; - for (unsigned int i = 0; i < labels_ptr->size(); ++i) { - (void)ret.emplace_back(json{}); + for (unsigned int i = 0; i < labels->size(); ++i) { + (*labels_ptr)->emplace_back(json{}); } - for (unsigned int i = 0; i < labels_ptr->size(); ++i) { + for (unsigned int i = 0; i < labels->size(); ++i) { json construct_json; for (unsigned int j = 0; j < columns.size(); ++j) { // construct json "f1": value @@ -895,22 +821,22 @@ std::pair> ShardReader::GetLabels(int page_id, int // convert the string to base type by schema if (schema[columns[j]]["type"] == "int32") { - construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); + construct_json[columns[j]] = StringToNum((*labels)[i][j]); } else if (schema[columns[j]]["type"] == "int64") { - construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); + construct_json[columns[j]] = StringToNum((*labels)[i][j]); } else if (schema[columns[j]]["type"] == "float32") { - construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); + construct_json[columns[j]] = StringToNum((*labels)[i][j]); } else if (schema[columns[j]]["type"] == "float64") { - construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); + construct_json[columns[j]] = StringToNum((*labels)[i][j]); } else { - construct_json[columns[j]] = std::string((*labels_ptr)[i][j]); + construct_json[columns[j]] = std::string((*labels)[i][j]); } } - ret[i] = construct_json; + (*(*labels_ptr))[i] = construct_json; } - return {SUCCESS, ret}; + return Status::OK(); } - return GetLabelsFromPage(page_id, shard_id, columns, criteria); + return GetLabelsFromPage(page_id, shard_id, columns, criteria, labels_ptr); } bool ResortRowGroups(std::tuple a, std::tuple b) { @@ -930,12 +856,10 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { MS_LOG(ERROR) << "Field " << category_field << " does not exist."; return -1; } - auto ret = - ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); - if (SUCCESS != ret.first) { - return -1; - } - std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::shared_ptr fn_ptr; + (void)ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field), + &fn_ptr); + std::string sql = "SELECT DISTINCT " + *fn_ptr + " FROM INDEXES"; std::vector threads = std::vector(shard_count); auto category_ptr = std::make_shared>(); sqlite3 *db = nullptr; @@ -948,6 +872,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { } threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr); } + for (int x = 0; x < shard_count; x++) { threads[x].join(); } @@ -955,11 +880,9 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { return category_ptr->size(); } -MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, - const std::shared_ptr &ops, int64_t *count, const int num_padded) { - if (SUCCESS != Init(file_paths, load_dataset)) { - return FAILED; - } +Status ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, + const std::shared_ptr &ops, int64_t *count, const int num_padded) { + RETURN_IF_NOT_OK(Init(file_paths, load_dataset)); int64_t num_samples = num_rows_; bool root = true; std::stack> stack_ops; @@ -988,10 +911,8 @@ MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths if (tmp != 0 && num_samples != -1) { num_samples = std::min(num_samples, tmp); } - if (-1 == num_samples) { - MS_LOG(ERROR) << "Number of samples exceeds the upper limit: " << std::numeric_limits::max(); - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(num_samples != -1, "Number of samples exceeds the upper limit: " + + std::to_string(std::numeric_limits::max())); } } else if (std::dynamic_pointer_cast(op)) { if (std::dynamic_pointer_cast(op)) { @@ -999,34 +920,30 @@ MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths if (root == true) { sampler_op->SetNumPaddedSamples(num_padded); num_samples = op->GetNumSamples(num_samples, 0); - if (-1 == num_samples) { - MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED( + num_samples != -1, "Dataset size plus number of padded samples is not divisible by number of shards."); root = false; } } else { num_samples = op->GetNumSamples(num_samples, 0); } } else { - if (num_padded > 0) num_samples += num_padded; + if (num_padded > 0) { + num_samples += num_padded; + } } } *count = num_samples; - return SUCCESS; + return Status::OK(); } -MSRStatus ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, - const std::vector &selected_columns, - const std::vector> &operators, int num_padded, - bool lazy_load) { +Status ShardReader::Open(const std::vector &file_paths, bool load_dataset, int n_consumer, + const std::vector &selected_columns, + const std::vector> &operators, int num_padded, bool lazy_load) { lazy_load_ = lazy_load; // Open file and set header by ShardReader - auto ret = Init(file_paths, load_dataset); - if (SUCCESS != ret) { - return ret; - } + RETURN_IF_NOT_OK(Init(file_paths, load_dataset)); auto thread_limit = GetMaxThreadNum(); if (n_consumer > thread_limit) { n_consumer = thread_limit; @@ -1036,11 +953,7 @@ MSRStatus ShardReader::Open(const std::vector &file_paths, bool loa } selected_columns_ = selected_columns; - - if (CheckColumnList(selected_columns_) == FAILED) { - MS_LOG(ERROR) << "Illegal column list"; - return ILLEGAL_COLUMN_LIST; - } + RETURN_IF_NOT_OK(CheckColumnList(selected_columns_)); // Initialize argument shard_count_ = static_cast(file_paths_.size()); @@ -1048,74 +961,37 @@ MSRStatus ShardReader::Open(const std::vector &file_paths, bool loa num_padded_ = num_padded; operators_ = operators; - - if (Open(n_consumer) == FAILED) { - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardReader::OpenPy(const std::vector &file_paths, bool load_dataset, const int &n_consumer, - const std::vector &selected_columns, - const std::vector> &operators) { - // Open file and set header by ShardReader - if (SUCCESS != Init(file_paths, load_dataset)) { - return FAILED; - } - // should remove blob field from selected_columns when call from python - std::vector columns(selected_columns); - auto blob_fields = GetBlobFields().second; - for (auto &blob_field : blob_fields) { - auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); - if (it != selected_columns.end()) { - columns.erase(columns.begin() + std::distance(selected_columns.begin(), it)); - } - } - if (CheckColumnList(columns) == FAILED) { - MS_LOG(ERROR) << "Illegal column list"; - return FAILED; - } - if (Open(n_consumer) == FAILED) { - return FAILED; - } - // Initialize argument - shard_count_ = static_cast(file_paths_.size()); - n_consumer_ = n_consumer; - - // Initialize columns which will be read - selected_columns_ = selected_columns; - operators_ = operators; - - return SUCCESS; + RETURN_IF_NOT_OK(Open(n_consumer)); + return Status::OK(); } -MSRStatus ShardReader::Launch(bool isSimpleReader) { +Status ShardReader::Launch(bool is_sample_read) { // Get all row groups' info auto row_group_summary = ReadRowGroupSummary(); // Sort row group by (group_id, shard_id), prepare for parallel reading std::sort(row_group_summary.begin(), row_group_summary.end(), ResortRowGroups); - if (CreateTasks(row_group_summary, operators_) != SUCCESS) { - MS_LOG(ERROR) << "Failed to launch read threads."; + if (CreateTasks(row_group_summary, operators_).IsError()) { interrupt_ = true; - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to launch read threads."); + } + if (is_sample_read) { + return Status::OK(); } - if (isSimpleReader) return SUCCESS; // Start provider consumer threads thread_set_ = std::vector(n_consumer_); - if (n_consumer_ <= 0 || n_consumer_ > kMaxConsumerCount) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(n_consumer_ > 0 && n_consumer_ <= kMaxConsumerCount, + "Number of consumer is out of range."); for (int x = 0; x < n_consumer_; ++x) { thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); } MS_LOG(INFO) << "Launch read thread successfully."; - return SUCCESS; + return Status::OK(); } -MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr &op) { +Status ShardReader::CreateTasksByCategory(const std::shared_ptr &op) { CheckIfColumnInIndex(selected_columns_); auto category_op = std::dynamic_pointer_cast(op); auto categories = category_op->GetCategories(); @@ -1123,27 +999,18 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr(op)) { num_samples = std::dynamic_pointer_cast(op)->GetNumSamples(); - if (num_samples < 0) { - MS_LOG(ERROR) << "Invalid parameter, num_samples must be greater than or equal to 0, but got " << num_samples; - return FAILED; - } - } - if (num_elements <= 0) { - MS_LOG(ERROR) << "Invalid parameter, num_elements must be greater than 0, but got " << num_elements; - return FAILED; + CHECK_FAIL_RETURN_UNEXPECTED( + num_samples >= 0, "Invalid parameter, num_samples must be greater than or equal to 0, but got " + num_samples); } + CHECK_FAIL_RETURN_UNEXPECTED(num_elements > 0, + "Invalid parameter, num_elements must be greater than 0, but got " + num_elements); if (categories.empty() == true) { std::string category_field = category_op->GetCategoryField(); int64_t num_categories = category_op->GetNumCategories(); - if (num_categories <= 0) { - MS_LOG(ERROR) << "Invalid parameter, num_categories must be greater than 0, but got " << num_elements; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(num_categories > 0, + "Invalid parameter, num_categories must be greater than 0, but got " + num_elements); auto category_ptr = std::make_shared>(); - auto ret = GetAllClasses(category_field, category_ptr); - if (SUCCESS != ret) { - return FAILED; - } + RETURN_IF_NOT_OK(GetAllClasses(category_field, category_ptr)); int i = 0; for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) { categories.emplace_back(category_field, *it); @@ -1155,27 +1022,26 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr= num_elements) break; - const auto &page_t = shard_header_->GetPage(shard_id, page_id); - const auto &page = page_t.first; - auto group_id = page->GetPageTypeID(); - auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); - if (SUCCESS != std::get<0>(details)) { - return FAILED; + auto pages_ptr = std::make_shared>(); + RETURN_IF_NOT_OK(GetPagesByCategory(shard_id, categories[categoryNo], &pages_ptr)); + for (const auto &page_id : *pages_ptr) { + if (category_index >= num_elements) { + break; } - auto offsets = std::get<4>(details); + std::shared_ptr page_ptr; + RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, page_id, &page_ptr)); + auto group_id = page_ptr->GetPageTypeID(); + std::shared_ptr row_group_brief_ptr; + RETURN_IF_NOT_OK( + ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_, &row_group_brief_ptr)); + auto offsets = std::get<3>(*row_group_brief_ptr); auto number_of_rows = offsets.size(); for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { if (category_index < num_elements) { categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, - std::get<4>(details)[iStart], std::get<5>(details)[iStart]); + std::get<3>(*row_group_brief_ptr)[iStart], + std::get<4>(*row_group_brief_ptr)[iStart]); category_index++; } } @@ -1186,98 +1052,86 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptrGetReplacement(), num_elements, num_samples); tasks_.InitSampleIds(); - if (SUCCESS != (*category_op)(tasks_)) { - return FAILED; - } - return SUCCESS; + RETURN_IF_NOT_OK((*category_op)(tasks_)); + return Status::OK(); } -MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, - const std::vector> &operators) { +Status ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, + const std::vector> &operators) { CheckIfColumnInIndex(selected_columns_); - - auto ret = ReadAllRowGroup(selected_columns_); - if (std::get<0>(ret) != SUCCESS) { - return FAILED; + std::shared_ptr row_group_ptr; + RETURN_IF_NOT_OK(ReadAllRowGroup(selected_columns_, &row_group_ptr)); + auto &offsets = std::get<0>(*row_group_ptr); + auto &local_columns = std::get<1>(*row_group_ptr); + CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount, "shard count is out of range."); + int sample_count = 0; + for (int shard_id = 0; shard_id < shard_count_; shard_id++) { + sample_count += offsets[shard_id].size(); + } + MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset."; + + // Init the tasks_ size + tasks_.ResizeTask(sample_count); + + // Init the task threads, maybe use ThreadPool is better + std::vector init_tasks_thread(shard_count_); + + uint32_t current_offset = 0; + for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { + init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() { + auto offset = current_offset; + for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { + tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], + std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, + local_columns[shard_id][i]); + offset++; + } + }); + current_offset += offsets[shard_id].size(); } - auto &offsets = std::get<1>(ret); - auto &local_columns = std::get<2>(ret); - if (shard_count_ <= kMaxFileCount) { - int sample_count = 0; - for (int shard_id = 0; shard_id < shard_count_; shard_id++) { - sample_count += offsets[shard_id].size(); - } - MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset."; - - // Init the tasks_ size - tasks_.ResizeTask(sample_count); - - // Init the task threads, maybe use ThreadPool is better - std::vector init_tasks_thread(shard_count_); - - uint32_t current_offset = 0; - for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { - init_tasks_thread[shard_id] = std::thread([this, &offsets, &local_columns, shard_id, current_offset]() { - auto offset = current_offset; - for (uint32_t i = 0; i < offsets[shard_id].size(); i += 1) { - tasks_.InsertTask(offset, TaskType::kCommonTask, offsets[shard_id][i][0], offsets[shard_id][i][1], - std::vector{offsets[shard_id][i][2], offsets[shard_id][i][3]}, - local_columns[shard_id][i]); - offset++; - } - }); - current_offset += offsets[shard_id].size(); - } - for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { - init_tasks_thread[shard_id].join(); - } - } else { - return FAILED; + for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { + init_tasks_thread[shard_id].join(); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardReader::CreateLazyTasksByRow(const std::vector> &row_group_summary, - const std::vector> &operators) { +Status ShardReader::CreateLazyTasksByRow(const std::vector> &row_group_summary, + const std::vector> &operators) { CheckIfColumnInIndex(selected_columns_); + CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxFileCount, "shard count is out of range."); + uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1]; + MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset."; + + // Init the tasks_ size + tasks_.ResizeTask(sample_count); + + // Init the task threads, maybe use ThreadPool is better + std::vector init_tasks_thread(shard_count_); + + for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { + // the offset indicate the shard start + uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1]; + + // the count indicate the number of samples in the shard + uint32_t shard_count = + shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1]; + init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() { + for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) { + // here "i - current_offset" indicate the sample id in the shard + tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json()); + } + }); + } - if (shard_count_ <= kMaxFileCount) { - uint32_t sample_count = shard_sample_count_[shard_sample_count_.size() - 1]; - MS_LOG(DEBUG) << "There are " << sample_count << " records in the dataset."; - - // Init the tasks_ size - tasks_.ResizeTask(sample_count); - - // Init the task threads, maybe use ThreadPool is better - std::vector init_tasks_thread(shard_count_); - - for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { - // the offset indicate the shard start - uint32_t current_offset = shard_id == 0 ? 0 : shard_sample_count_[shard_id - 1]; - - // the count indicate the number of samples in the shard - uint32_t shard_count = - shard_id == 0 ? shard_sample_count_[0] : shard_sample_count_[shard_id] - shard_sample_count_[shard_id - 1]; - init_tasks_thread[shard_id] = std::thread([this, shard_id, current_offset, shard_count]() { - for (uint32_t i = current_offset; i < shard_count + current_offset; ++i) { - // here "i - current_offset" indicate the sample id in the shard - tasks_.InsertTask(i, TaskType::kCommonTask, shard_id, i - current_offset, {}, json()); - } - }); - } - - for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { - init_tasks_thread[shard_id].join(); - } - } else { - return FAILED; + for (uint32_t shard_id = 0; shard_id < shard_count_; shard_id++) { + init_tasks_thread[shard_id].join(); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardReader::CreateTasks(const std::vector> &row_group_summary, - const std::vector> &operators) { +Status ShardReader::CreateTasks(const std::vector> &row_group_summary, + const std::vector> &operators) { int category_operator = -1; for (uint32_t i = 0; i < operators.size(); ++i) { const auto &op = operators[i]; @@ -1289,13 +1143,9 @@ MSRStatus ShardReader::CreateTasks(const std::vector(op)) continue; + if (std::dynamic_pointer_cast(op)) { + continue; + } if (std::dynamic_pointer_cast(op) || std::dynamic_pointer_cast(op)) { op->SetShardSampleCount(shard_sample_count_); } - - if (SUCCESS != (*op)(tasks_)) { - return FAILED; - } + RETURN_IF_NOT_OK((*op)(tasks_)); } if (tasks_.permutation_.empty()) tasks_.MakePerm(); @@ -1332,16 +1177,14 @@ MSRStatus ShardReader::CreateTasks(const std::vector *task_content_ptr) { + RETURN_UNEXPECTED_IF_NULL(task_content_ptr); // All tasks are done - if (task_id >= static_cast(tasks_.Size())) { - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - + CHECK_FAIL_RETURN_UNEXPECTED(task_id < static_cast(tasks_.Size()), "task id is out of range."); uint32_t shard_id = 0; uint32_t group_id = 0; uint32_t blob_start = 0; @@ -1353,8 +1196,9 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ // check task type auto task_type = std::get<0>(task); if (task_type == TaskType::kPaddedTask) { - return std::make_pair(SUCCESS, - std::make_pair(TaskType::kPaddedTask, std::vector, json>>())); + *task_content_ptr = + std::make_shared(TaskType::kPaddedTask, std::vector, json>>()); + return Status::OK(); } shard_id = std::get<0>(std::get<1>(task)); // shard id @@ -1369,13 +1213,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ uint32_t sample_id_in_shard = std::get<1>(std::get<1>(task)); // read the meta from index - auto row_meta = ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, sample_id_in_shard); - if (std::get<0>(row_meta) != SUCCESS) { - return std::make_pair( - FAILED, std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - auto &offsets = std::get<1>(row_meta); - auto &local_columns = std::get<2>(row_meta); + std::shared_ptr row_group_ptr; + RETURN_IF_NOT_OK(ReadRowGroupByShardIDAndSampleID(selected_columns_, shard_id, sample_id_in_shard, &row_group_ptr)); + auto &offsets = std::get<0>(*row_group_ptr); + auto &local_columns = std::get<1>(*row_group_ptr); group_id = offsets[shard_id][0][1]; // group_id blob_start = offsets[shard_id][0][2]; // blob start @@ -1384,42 +1225,35 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ } // read the blob from data file - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); - } - const std::shared_ptr &page = ret.second; + std::shared_ptr page_ptr; + RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr)); + MS_LOG(DEBUG) << "Success to get page by group id."; // Pack image list std::vector images(blob_end - blob_start); - auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + blob_start; + auto file_offset = header_size_ + page_size_ * (page_ptr->GetPageID()) + blob_start; auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, - std::make_pair(TaskType::kCommonTask, std::vector, json>>())); + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } - auto &io_read = file_streams_random_[consumer_id][shard_id]->read(reinterpret_cast(&images[0]), blob_end - blob_start); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; file_streams_random_[consumer_id][shard_id]->close(); - return std::make_pair(FAILED, - std::pair(TaskType::kCommonTask, std::vector, json>>())); + RETURN_STATUS_UNEXPECTED("Failed to read file."); } // Deliver batch data to output map std::vector, json>> batch; batch.emplace_back(std::move(images), std::move(var_fields)); - return std::make_pair(SUCCESS, std::make_pair(TaskType::kCommonTask, std::move(batch))); + *task_content_ptr = std::make_shared(TaskType::kCommonTask, std::move(batch)); + return Status::OK(); } -MSRStatus ShardReader::ConsumerByRow(int consumer_id) { +void ShardReader::ConsumerByRow(int consumer_id) { // Set thread name #if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__) auto thread_id = kThreadName + std::to_string(consumer_id); @@ -1435,13 +1269,15 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { // All tasks are done if (sample_id_pos >= static_cast(tasks_.sample_ids_.size())) { - return FAILED; + return; } - const auto &ret = ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id); - if (SUCCESS != ret.first) { - return FAILED; + auto task_content_ptr = + std::make_shared(TaskType::kCommonTask, std::vector, json>>()); + if (ConsumerOneTask(tasks_.sample_ids_[sample_id_pos], consumer_id, &task_content_ptr).IsError()) { + MS_LOG(ERROR) << "Error in ConsumerOneTask."; + return; } - const auto &batch = (ret.second).second; + const auto &batch = (*task_content_ptr).second; // Hanging if maximum map size exceeded // otherwise, set batch data in map { @@ -1449,7 +1285,7 @@ MSRStatus ShardReader::ConsumerByRow(int consumer_id) { cv_delivery_.wait(lck, [sample_id_pos, this] { return interrupt_ || sample_id_pos <= deliver_id_ + kNumBatchInMap; }); if (interrupt_) { - return SUCCESS; + return; } delivery_map_[sample_id_pos] = std::make_shared, json>>>(std::move(batch)); @@ -1482,53 +1318,40 @@ std::vector, json>> ShardReader::GetNext() { return *res; } -std::pair, json>>> ShardReader::GetNextById( - const int64_t &task_id, const int32_t &consumer_id) { +TASK_CONTENT ShardReader::GetNextById(const int64_t &task_id, const int32_t &consumer_id) { + auto task_content_ptr = + std::make_shared(TaskType::kCommonTask, std::vector, json>>()); if (interrupt_) { - return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); - } - const auto &ret = ConsumerOneTask(task_id, consumer_id); - if (SUCCESS != ret.first) { - return std::make_pair(TaskType::kCommonTask, std::vector, json>>()); + return *task_content_ptr; } - return std::move(ret.second); + (void)ConsumerOneTask(task_id, consumer_id, &task_content_ptr); + return std::move(*task_content_ptr); } -std::pair>> ShardReader::UnCompressBlob( - const std::vector &raw_blob_data) { +Status ShardReader::UnCompressBlob(const std::vector &raw_blob_data, + std::shared_ptr>> *blob_data_ptr) { + RETURN_UNEXPECTED_IF_NULL(blob_data_ptr); auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_; auto blob_fields = GetBlobFields().second; - std::vector> blob_data; for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) { if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue; const unsigned char *data = nullptr; std::unique_ptr data_ptr; uint64_t n_bytes = 0; - auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes); - if (ret != SUCCESS) { - MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << "."; - return {FAILED, std::vector>(blob_fields.size(), std::vector())}; - } + RETURN_IF_NOT_OK( + shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes)); if (data == nullptr) { data = reinterpret_cast(data_ptr.get()); } std::vector column(data, data + (n_bytes / sizeof(unsigned char))); - blob_data.push_back(column); + (*blob_data_ptr)->push_back(column); } - return {SUCCESS, blob_data}; + return Status::OK(); } -std::vector>, pybind11::object>> ShardReader::GetNextPy() { - auto res = GetNext(); - vector>, pybind11::object>> data; - std::transform(res.begin(), res.end(), std::back_inserter(data), - [this](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - auto ret = UnCompressBlob(std::get<0>(item)); - return std::make_tuple(ret.second, std::move(obj)); - }); - return data; +Status ShardReader::GetTotalBlobSize(int64_t *total_blob_size) { + *total_blob_size = total_blob_size_; + return Status::OK(); } void ShardReader::Reset() { @@ -1550,11 +1373,13 @@ void ShardReader::ShuffleTask() { } for (const auto &op : operators_) { if (std::dynamic_pointer_cast(op) && has_sharding == false) { - if (SUCCESS != (*op)(tasks_)) { + auto s = (*op)(tasks_); + if (s.IsError()) { MS_LOG(WARNING) << "Redo randomSampler failed."; } } else if (std::dynamic_pointer_cast(op)) { - if (SUCCESS != (*op)(tasks_)) { + auto s = (*op)(tasks_); + if (s.IsError()) { MS_LOG(WARNING) << "Redo distributeSampler failed."; } } diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc index cf7c036c1a..c295319bca 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_segment.cc @@ -30,9 +30,13 @@ namespace mindspore { namespace mindrecord { ShardSegment::ShardSegment() { SetAllInIndex(false); } -std::pair> ShardSegment::GetCategoryFields() { +Status ShardSegment::GetCategoryFields(std::shared_ptr> *fields_ptr) { + RETURN_UNEXPECTED_IF_NULL(fields_ptr); // Skip if already populated - if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_}; + if (!candidate_category_fields_.empty()) { + *fields_ptr = std::make_shared>(candidate_category_fields_); + return Status::OK(); + } std::string sql = "PRAGMA table_info(INDEXES);"; std::vector> field_names; @@ -40,11 +44,12 @@ std::pair> ShardSegment::GetCategoryFields() { char *errmsg = nullptr; int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + std::ostringstream oss; + oss << "Error in select statement, sql: " << sql + ", error: " << errmsg; sqlite3_free(errmsg); sqlite3_close(database_paths_[0]); database_paths_[0] = nullptr; - return {FAILED, vector{}}; + RETURN_STATUS_UNEXPECTED(oss.str()); } else { MS_LOG(INFO) << "Get " << static_cast(field_names.size()) << " records from index."; } @@ -55,53 +60,46 @@ std::pair> ShardSegment::GetCategoryFields() { sqlite3_free(errmsg); sqlite3_close(database_paths_[0]); database_paths_[0] = nullptr; - return {FAILED, vector{}}; + RETURN_STATUS_UNEXPECTED("idx is out of range."); } candidate_category_fields_.push_back(field_names[idx][1]); idx += 2; } sqlite3_free(errmsg); - return {SUCCESS, candidate_category_fields_}; + *fields_ptr = std::make_shared>(candidate_category_fields_); + return Status::OK(); } -MSRStatus ShardSegment::SetCategoryField(std::string category_field) { - if (GetCategoryFields().first != SUCCESS) { - MS_LOG(ERROR) << "Get candidate category field failed"; - return FAILED; - } +Status ShardSegment::SetCategoryField(std::string category_field) { + std::shared_ptr> fields_ptr; + RETURN_IF_NOT_OK(GetCategoryFields(&fields_ptr)); category_field = category_field + "_0"; if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), [category_field](std::string x) { return x == category_field; })) { current_category_field_ = category_field; - return SUCCESS; + return Status::OK(); } - MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field."; - return FAILED; + RETURN_STATUS_UNEXPECTED("Field " + category_field + " is not a candidate category field."); } -std::pair ShardSegment::ReadCategoryInfo() { +Status ShardSegment::ReadCategoryInfo(std::shared_ptr *category_ptr) { + RETURN_UNEXPECTED_IF_NULL(category_ptr); MS_LOG(INFO) << "Read category begin"; - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info failed"; - return {FAILED, ""}; - } + auto category_info_ptr = std::make_shared(); + RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); // Convert category info to json string - auto category_json_string = ToJsonForCategory(ret.second); + *category_ptr = std::make_shared(ToJsonForCategory(*category_info_ptr)); MS_LOG(INFO) << "Read category end"; - return {SUCCESS, category_json_string}; + return Status::OK(); } -std::pair>> ShardSegment::WrapCategoryInfo() { +Status ShardSegment::WrapCategoryInfo(std::shared_ptr *category_info_ptr) { + RETURN_UNEXPECTED_IF_NULL(category_info_ptr); std::map counter; - - if (!ValidateFieldName(current_category_field_)) { - MS_LOG(ERROR) << "category field error from index, it is: " << current_category_field_; - return {FAILED, std::vector>()}; - } - + CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_), + "Category field error from index, it is: " + current_category_field_); std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; @@ -109,13 +107,13 @@ std::pair>> ShardSegmen std::vector> field_count; char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg); - if (rc != SQLITE_OK) { - MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) { + std::ostringstream oss; + oss << "Error in select statement, sql: " << sql + ", error: " << errmsg; sqlite3_free(errmsg); sqlite3_close(db); db = nullptr; - return {FAILED, std::vector>()}; + RETURN_STATUS_UNEXPECTED(oss.str()); } else { MS_LOG(INFO) << "Get " << static_cast(field_count.size()) << " records from index."; } @@ -127,14 +125,14 @@ std::pair>> ShardSegmen } int idx = 0; - std::vector> category_vec(counter.size()); - (void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple item) { - return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); - }); - return {SUCCESS, std::move(category_vec)}; + (*category_info_ptr)->resize(counter.size()); + (void)std::transform( + counter.begin(), counter.end(), (*category_info_ptr)->begin(), + [&idx](std::tuple item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); }); + return Status::OK(); } -std::string ShardSegment::ToJsonForCategory(const std::vector> &tri_vec) { +std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) { std::vector category_json_vec; for (auto q : tri_vec) { json j; @@ -152,27 +150,20 @@ std::string ShardSegment::ToJsonForCategory(const std::vector>> ShardSegment::ReadAtPageById(int64_t category_id, - int64_t page_no, - int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector>{}}; - } - if (category_id >= static_cast(ret.second.size()) || category_id < 0) { - MS_LOG(ERROR) << "Illegal category id, id: " << category_id; - return {FAILED, std::vector>{}}; - } - int total_rows_in_category = std::get<2>(ret.second[category_id]); +Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr>> *page_ptr) { + RETURN_UNEXPECTED_IF_NULL(page_ptr); + auto category_info_ptr = std::make_shared(); + RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); + CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast(category_info_ptr->size()) && category_id >= 0, + "Invalid category id, id: " + std::to_string(category_id)); + int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]); // Quit if category not found or page number is out of range - if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 || - page_no * n_rows_of_page >= total_rows_in_category) { - MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page; - return {FAILED, std::vector>{}}; - } + CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 && + page_no * n_rows_of_page < total_rows_in_category, + "Invalid page no / page size, page no: " + std::to_string(page_no) + + ", page size: " + std::to_string(n_rows_of_page)); - std::vector> page; auto row_group_summary = ReadRowGroupSummary(); uint64_t i_start = page_no * n_rows_of_page; @@ -183,12 +174,12 @@ std::pair>> ShardSegment::ReadAtPage auto shard_id = std::get<0>(rg); auto group_id = std::get<1>(rg); - auto details = ReadRowGroupCriteria( - group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); - if (SUCCESS != std::get<0>(details)) { - return {FAILED, std::vector>{}}; - } - auto offsets = std::get<4>(details); + std::shared_ptr row_group_brief_ptr; + RETURN_IF_NOT_OK(ReadRowGroupCriteria( + group_id, shard_id, + std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""}, + &row_group_brief_ptr)); + auto offsets = std::get<3>(*row_group_brief_ptr); uint64_t number_of_rows = offsets.size(); if (idx + number_of_rows < i_start) { idx += number_of_rows; @@ -197,131 +188,116 @@ std::pair>> ShardSegment::ReadAtPage for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { if (idx >= i_start && idx < i_end) { - auto ret1 = PackImages(group_id, shard_id, offsets[i]); - if (SUCCESS != ret1.first) { - return {FAILED, std::vector>{}}; - } - page.push_back(std::move(ret1.second)); + auto images_ptr = std::make_shared>(); + RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr)); + (*page_ptr)->push_back(std::move(*images_ptr)); } } } - return {SUCCESS, std::move(page)}; + return Status::OK(); } -std::pair> ShardSegment::PackImages(int group_id, int shard_id, - std::vector offset) { - const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); - if (SUCCESS != ret.first) { - return {FAILED, std::vector()}; - } - const std::shared_ptr &blob_page = ret.second; - +Status ShardSegment::PackImages(int group_id, int shard_id, std::vector offset, + std::shared_ptr> *images_ptr) { + RETURN_UNEXPECTED_IF_NULL(images_ptr); + std::shared_ptr page_ptr; + RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr)); // Pack image list - std::vector images(offset[1] - offset[0]); - auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; + (*images_ptr)->resize(offset[1] - offset[0]); + + auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0]; auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; file_streams_random_[0][shard_id]->close(); - return {FAILED, {}}; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } - auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast(&images[0]), offset[1] - offset[0]); + auto &io_read = + file_streams_random_[0][shard_id]->read(reinterpret_cast(&((*(*images_ptr))[0])), offset[1] - offset[0]); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; file_streams_random_[0][shard_id]->close(); - return {FAILED, {}}; + RETURN_STATUS_UNEXPECTED("Failed to read file."); } - - return {SUCCESS, std::move(images)}; + return Status::OK(); } -std::pair>> ShardSegment::ReadAtPageByName(std::string category_name, - int64_t page_no, - int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector>{}}; - } - for (const auto &categories : ret.second) { +Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr>> *pages_ptr) { + RETURN_UNEXPECTED_IF_NULL(pages_ptr); + auto category_info_ptr = std::make_shared(); + RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); + for (const auto &categories : *category_info_ptr) { if (std::get<1>(categories) == category_name) { - auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page); - return result; + RETURN_IF_NOT_OK(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr)); + return Status::OK(); } } - return {FAILED, std::vector>()}; + RETURN_STATUS_UNEXPECTED("Category name can not match."); } -std::pair, json>>> ShardSegment::ReadAllAtPageById( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS || category_id >= static_cast(ret.second.size())) { - MS_LOG(ERROR) << "Illegal category id, id: " << category_id; - return {FAILED, std::vector, json>>{}}; - } - int total_rows_in_category = std::get<2>(ret.second[category_id]); - // Quit if category not found or page number is out of range - if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) { - MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page; - return {FAILED, std::vector, json>>{}}; - } +Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr *pages_ptr) { + RETURN_UNEXPECTED_IF_NULL(pages_ptr); + auto category_info_ptr = std::make_shared(); + RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); + CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast(category_info_ptr->size()), + "Invalid category id: " + std::to_string(category_id)); - std::vector, json>> page; + int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]); + // Quit if category not found or page number is out of range + CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 && + page_no * n_rows_of_page < total_rows_in_category, + "Invalid page no / page size / total size, page no: " + std::to_string(page_no) + + ", page size of page: " + std::to_string(n_rows_of_page) + + ", total size: " + std::to_string(total_rows_in_category)); auto row_group_summary = ReadRowGroupSummary(); int i_start = page_no * n_rows_of_page; int i_end = std::min(static_cast(total_rows_in_category), (page_no + 1) * n_rows_of_page); int idx = 0; for (const auto &rg : row_group_summary) { - if (idx >= i_end) break; + if (idx >= i_end) { + break; + } auto shard_id = std::get<0>(rg); auto group_id = std::get<1>(rg); - auto details = ReadRowGroupCriteria( - group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); - if (SUCCESS != std::get<0>(details)) { - return {FAILED, std::vector, json>>{}}; - } - auto offsets = std::get<4>(details); - auto labels = std::get<5>(details); + std::shared_ptr row_group_brief_ptr; + RETURN_IF_NOT_OK(ReadRowGroupCriteria( + group_id, shard_id, + std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""}, + &row_group_brief_ptr)); + auto offsets = std::get<3>(*row_group_brief_ptr); + auto labels = std::get<4>(*row_group_brief_ptr); int number_of_rows = offsets.size(); if (idx + number_of_rows < i_start) { idx += number_of_rows; continue; } - - if (number_of_rows > static_cast(labels.size())) { - MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows; - return {FAILED, std::vector, json>>{}}; - } + CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast(labels.size()), + "Invalid row number of page: " + number_of_rows); for (int i = 0; i < number_of_rows; ++i, ++idx) { if (idx >= i_start && idx < i_end) { - auto ret1 = PackImages(group_id, shard_id, offsets[i]); - if (SUCCESS != ret1.first) { - return {FAILED, std::vector, json>>{}}; - } - page.emplace_back(std::move(ret1.second), std::move(labels[i])); + auto images_ptr = std::make_shared>(); + RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr)); + (*pages_ptr)->emplace_back(std::move(*images_ptr), std::move(labels[i])); } } } - return {SUCCESS, std::move(page)}; + return Status::OK(); } -std::pair, json>>> ShardSegment::ReadAllAtPageByName( - std::string category_name, int64_t page_no, int64_t n_rows_of_page) { - auto ret = WrapCategoryInfo(); - if (ret.first != SUCCESS) { - MS_LOG(ERROR) << "Get category info"; - return {FAILED, std::vector, json>>{}}; - } - +Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, + std::shared_ptr *pages_ptr) { + RETURN_UNEXPECTED_IF_NULL(pages_ptr); + auto category_info_ptr = std::make_shared(); + RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); // category_name to category_id int64_t category_id = -1; - for (const auto &categories : ret.second) { + for (const auto &categories : *category_info_ptr) { std::string categories_name = std::get<1>(categories); if (categories_name == category_name) { @@ -329,45 +305,8 @@ std::pair, json>>> ShardS break; } } - - if (category_id == -1) { - return {FAILED, std::vector, json>>{}}; - } - - return ReadAllAtPageById(category_id, page_no, n_rows_of_page); -} - -std::pair, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( - int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { - auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page); - if (res.first != SUCCESS) { - return {FAILED, std::vector, pybind11::object>>{}}; - } - - vector, pybind11::object>> json_data; - std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), - [](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - return std::make_tuple(std::get<0>(item), std::move(obj)); - }); - return {SUCCESS, std::move(json_data)}; -} - -std::pair, pybind11::object>>> ShardSegment::ReadAtPageByNamePy( - std::string category_name, int64_t page_no, int64_t n_rows_of_page) { - auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page); - if (res.first != SUCCESS) { - return {FAILED, std::vector, pybind11::object>>{}}; - } - vector, pybind11::object>> json_data; - std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), - [](const std::tuple, json> &item) { - auto &j = std::get<1>(item); - pybind11::object obj = nlohmann::detail::FromJsonImpl(j); - return std::make_tuple(std::get<0>(item), std::move(obj)); - }); - return {SUCCESS, std::move(json_data)}; + CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "Invalid category name."); + return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr); } std::pair> ShardSegment::GetBlobFields() { @@ -382,7 +321,9 @@ std::pair> ShardSegment::GetBlobFields() { } std::string ShardSegment::CleanUp(std::string field_name) { - while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back(); + while (field_name.back() >= '0' && field_name.back() <= '9') { + field_name.pop_back(); + } field_name.pop_back(); return field_name; } diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc index e80d16c212..363da56f70 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc @@ -40,82 +40,63 @@ ShardWriter::~ShardWriter() { } } -MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { +Status ShardWriter::GetFullPathFromFileName(const std::vector &paths) { // Get full path from file name for (const auto &path : paths) { - if (!CheckIsValidUtf8(path)) { - MS_LOG(ERROR) << "The filename contains invalid uft-8 data: " << path << "."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(CheckIsValidUtf8(path), "The filename contains invalid uft-8 data: " + path); char resolved_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; - if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Secure func failed"; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) == EOK, + "Secure func failed"); #if defined(_WIN32) || defined(_WIN64) - if (_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX) == nullptr) { - MS_LOG(ERROR) << "Invalid file path"; - return FAILED; - } - if (_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX) == nullptr) { - MS_LOG(DEBUG) << "Path " << resolved_path; - } + RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, dirname(&(buf[0])), PATH_MAX)); + RETURN_UNEXPECTED_IF_NULL(_fullpath(resolved_path, common::SafeCStr(path), PATH_MAX)); #else - if (realpath(dirname(&(buf[0])), resolved_path) == nullptr) { - MS_LOG(ERROR) << "Invalid file path"; - return FAILED; - } + RETURN_UNEXPECTED_IF_NULL(realpath(dirname(&(buf[0])), resolved_path)); if (realpath(common::SafeCStr(path), resolved_path) == nullptr) { MS_LOG(DEBUG) << "Path " << resolved_path; } #endif file_paths_.emplace_back(string(resolved_path)); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::OpenDataFiles(bool append) { +Status ShardWriter::OpenDataFiles(bool append) { // Open files for (const auto &file : file_paths_) { auto realpath = Common::GetRealPath(file); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << file; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + file); std::shared_ptr fs = std::make_shared(); if (!append) { // if not append and mindrecord file exist, return FAILED fs->open(realpath.value(), std::ios::in | std::ios::binary); if (fs->good()) { - MS_LOG(ERROR) << "MindRecord file already existed, please delete file: " << common::SafeCStr(file); fs->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("MindRecord file already existed, please delete file: " + file); } fs->close(); - // open the mindrecord file to write fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary | std::ios::trunc); if (!fs->good()) { - MS_LOG(ERROR) << "MindRecord file could not opened: " << file; - return FAILED; + RETURN_STATUS_UNEXPECTED("MindRecord file could not opened: " + file); } } else { // open the mindrecord file to append fs->open(common::SafeCStr(file), std::ios::out | std::ios::in | std::ios::binary); if (!fs->good()) { - MS_LOG(ERROR) << "MindRecord file could not opened for append: " << file; - return FAILED; + fs->close(); + RETURN_STATUS_UNEXPECTED("MindRecord file could not opened for append: " + file); } } MS_LOG(INFO) << "Open shard file successfully."; file_streams_.push_back(fs); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::RemoveLockFile() { +Status ShardWriter::RemoveLockFile() { // Remove temporary file int ret = std::remove(pages_file_.c_str()); if (ret == 0) { @@ -126,125 +107,68 @@ MSRStatus ShardWriter::RemoveLockFile() { if (ret == 0) { MS_LOG(DEBUG) << "Remove lock file."; } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::InitLockFile() { - if (file_paths_.size() == 0) { - MS_LOG(ERROR) << "File path not initialized."; - return FAILED; - } +Status ShardWriter::InitLockFile() { + CHECK_FAIL_RETURN_UNEXPECTED(file_paths_.size() != 0, "File path not initialized."); lock_file_ = file_paths_[0] + kLockFileSuffix; pages_file_ = file_paths_[0] + kPageFileSuffix; - - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove file failed."; - return FAILED; - } - return SUCCESS; + RETURN_IF_NOT_OK(RemoveLockFile()); + return Status::OK(); } -MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { +Status ShardWriter::Open(const std::vector &paths, bool append) { shard_count_ = paths.size(); - if (shard_count_ > kMaxShardCount || shard_count_ == 0) { - MS_LOG(ERROR) << "The Shard Count greater than max value(1000) or equal to 0, but got " << shard_count_; - return FAILED; - } - if (schema_count_ > kMaxSchemaCount) { - MS_LOG(ERROR) << "The schema Count greater than max value(1), but got " << schema_count_; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ <= kMaxShardCount && shard_count_ != 0, + "The Shard Count greater than max value(1000) or equal to 0, but got " + shard_count_); + CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ <= kMaxSchemaCount, + "The schema Count greater than max value(1), but got " + schema_count_); // Get full path from file name - if (GetFullPathFromFileName(paths) == FAILED) { - MS_LOG(ERROR) << "Get full path from file name failed."; - return FAILED; - } - + RETURN_IF_NOT_OK(GetFullPathFromFileName(paths)); // Open files - if (OpenDataFiles(append) == FAILED) { - MS_LOG(ERROR) << "Open data files failed."; - return FAILED; - } - + RETURN_IF_NOT_OK(OpenDataFiles(append)); // Init lock file - if (InitLockFile() == FAILED) { - MS_LOG(ERROR) << "Init lock file failed."; - return FAILED; - } - return SUCCESS; + RETURN_IF_NOT_OK(InitLockFile()); + return Status::OK(); } -MSRStatus ShardWriter::OpenForAppend(const std::string &path) { - if (!IsLegalFile(path)) { - return FAILED; - } - auto ret1 = ShardHeader::BuildSingleHeader(path); - if (ret1.first != SUCCESS) { - return FAILED; - } - auto json_header = ret1.second; - auto ret2 = GetDatasetFiles(path, json_header["shard_addresses"]); - if (SUCCESS != ret2.first) { - return FAILED; - } - auto addresses = ret2.second; +Status ShardWriter::OpenForAppend(const std::string &path) { + CHECK_FAIL_RETURN_UNEXPECTED(IsLegalFile(path), "Invalid file pacth."); + std::shared_ptr header_ptr; + RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(path, &header_ptr)); + auto ds = std::make_shared>(); + RETURN_IF_NOT_OK(GetDatasetFiles(path, (*header_ptr)["shard_addresses"], &ds)); ShardHeader header = ShardHeader(); - if (header.BuildDataset(addresses) == FAILED) { - return FAILED; - } + RETURN_IF_NOT_OK(header.BuildDataset(*ds)); shard_header_ = std::make_shared(header); - MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); - if (ret == FAILED) { - return FAILED; - } - ret = SetPageSize(shard_header_->GetPageSize()); - if (ret == FAILED) { - return FAILED; - } + RETURN_IF_NOT_OK(SetHeaderSize(shard_header_->GetHeaderSize())); + RETURN_IF_NOT_OK(SetPageSize(shard_header_->GetPageSize())); compression_size_ = shard_header_->GetCompressionSize(); - ret = Open(addresses, true); - if (ret == FAILED) { - MS_LOG(ERROR) << "Invalid file, failed to open file: " << addresses; - return FAILED; - } + RETURN_IF_NOT_OK(Open(*ds, true)); shard_column_ = std::make_shared(shard_header_); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::Commit() { +Status ShardWriter::Commit() { // Read pages file std::ifstream page_file(pages_file_.c_str()); if (page_file.good()) { page_file.close(); - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return FAILED; - } - } - - if (WriteShardHeader() == FAILED) { - MS_LOG(ERROR) << "Write metadata failed"; - return FAILED; + RETURN_IF_NOT_OK(shard_header_->FileToPages(pages_file_)); } + RETURN_IF_NOT_OK(WriteShardHeader()); MS_LOG(INFO) << "Write metadata successfully."; - // Remove lock file - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove lock file failed."; - return FAILED; - } + RETURN_IF_NOT_OK(RemoveLockFile()); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) { - MSRStatus ret = header_data->InitByFiles(file_paths_); - if (ret == FAILED) { - return FAILED; - } - +Status ShardWriter::SetShardHeader(std::shared_ptr header_data) { + RETURN_IF_NOT_OK(header_data->InitByFiles(file_paths_)); // set fields in mindrecord when empty std::vector> fields = header_data->GetFields(); if (fields.empty()) { @@ -264,11 +188,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) } // only blob data if (!fields.empty()) { - ret = header_data->AddIndexFields(fields); - if (ret == FAILED) { - MS_LOG(ERROR) << "Add index field failed"; - return FAILED; - } + RETURN_IF_NOT_OK(header_data->AddIndexFields(fields)); } } @@ -276,36 +196,25 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) shard_header_->SetHeaderSize(header_size_); shard_header_->SetPageSize(page_size_); shard_column_ = std::make_shared(shard_header_); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) { +Status ShardWriter::SetHeaderSize(const uint64_t &header_size) { // header_size [16KB, 128MB] - if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { - MS_LOG(ERROR) << "Header size should between 16KB and 128MB."; - return FAILED; - } - if (header_size % 4 != 0) { - MS_LOG(ERROR) << "Header size should be divided by four."; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(header_size >= kMinHeaderSize && header_size <= kMaxHeaderSize, + "Header size should between 16KB and 128MB."); + CHECK_FAIL_RETURN_UNEXPECTED(header_size % 4 == 0, "Header size should be divided by four."); header_size_ = header_size; - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) { +Status ShardWriter::SetPageSize(const uint64_t &page_size) { // PageSize [32KB, 256MB] - if (page_size < kMinPageSize || page_size > kMaxPageSize) { - MS_LOG(ERROR) << "Page size should between 16KB and 256MB."; - return FAILED; - } - if (page_size % 4 != 0) { - MS_LOG(ERROR) << "Page size should be divided by four."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(page_size >= kMinPageSize && page_size <= kMaxPageSize, + "Page size should between 16KB and 256MB."); + CHECK_FAIL_RETURN_UNEXPECTED(page_size % 4 == 0, "Page size should be divided by four."); page_size_ = page_size; - return SUCCESS; + return Status::OK(); } void ShardWriter::DeleteErrorData(std::map> &raw_data, @@ -348,17 +257,16 @@ void ShardWriter::PopulateMutexErrorData(const int &row, const std::string &mess (void)err_raw_data.insert(std::make_pair(row, message)); } -MSRStatus ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, - std::map &err_raw_data) { +Status ShardWriter::CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, + std::map &err_raw_data) { auto data_type = std::string(value["type"].get()); - if ((data_type == "int32" && !data[key].is_number_integer()) || (data_type == "int64" && !data[key].is_number_integer()) || (data_type == "float32" && !data[key].is_number_float()) || (data_type == "float64" && !data[key].is_number_float()) || (data_type == "string" && !data[key].is_string())) { std::string message = "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is not matched"; PopulateMutexErrorData(i, message, err_raw_data); - return FAILED; + RETURN_STATUS_UNEXPECTED(message); } if (data_type == "int32" && data[key].is_number_integer()) { @@ -368,10 +276,10 @@ MSRStatus ShardWriter::CheckDataTypeAndValue(const std::string &key, const json std::string message = "field: " + key + " type : " + data_type + " value: " + data[key].dump() + " is out of range"; PopulateMutexErrorData(i, message, err_raw_data); - return FAILED; + RETURN_STATUS_UNEXPECTED(message); } } - return SUCCESS; + return Status::OK(); } void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const std::vector &sub_raw_data, @@ -396,14 +304,14 @@ void ShardWriter::CheckSliceData(int start_row, int end_row, json schema, const continue; } - if (CheckDataTypeAndValue(key, value, data, i, err_raw_data) != SUCCESS) { + if (CheckDataTypeAndValue(key, value, data, i, err_raw_data).IsError()) { break; } } } } -MSRStatus ShardWriter::CheckData(const std::map> &raw_data) { +Status ShardWriter::CheckData(const std::map> &raw_data) { auto rawdata_iter = raw_data.begin(); // make sure rawdata match schema @@ -411,12 +319,10 @@ MSRStatus ShardWriter::CheckData(const std::map> &ra // used for storing error std::map sub_err_mg; int schema_id = rawdata_iter->first; - auto result = shard_header_->GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - return FAILED; - } - json schema = result.first->GetSchema()["schema"]; - for (const auto &field : result.first->GetBlobFields()) { + std::shared_ptr schema_ptr; + RETURN_IF_NOT_OK(shard_header_->GetSchemaByID(schema_id, &schema_ptr)); + json schema = schema_ptr->GetSchema()["schema"]; + for (const auto &field : schema_ptr->GetBlobFields()) { (void)schema.erase(field); } std::vector sub_raw_data = rawdata_iter->second; @@ -424,9 +330,7 @@ MSRStatus ShardWriter::CheckData(const std::map> &ra // calculate start position and end position for each thread int batch_size = rawdata_iter->second.size() / shard_count_; int thread_num = shard_count_; - if (thread_num <= 0) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid thread number."); if (thread_num > kMaxThreadCount) { thread_num = kMaxThreadCount; } @@ -445,9 +349,7 @@ MSRStatus ShardWriter::CheckData(const std::map> &ra thread_set[x] = std::thread(&ShardWriter::CheckSliceData, this, start_row, end_row, schema, std::ref(sub_raw_data), std::ref(sub_err_mg)); } - if (thread_num > kMaxThreadCount) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(thread_num <= kMaxThreadCount, "Invalid thread number."); // Wait for threads done for (int x = 0; x < thread_num; ++x) { thread_set[x].join(); @@ -455,18 +357,16 @@ MSRStatus ShardWriter::CheckData(const std::map> &ra (void)err_mg_.insert(std::make_pair(schema_id, sub_err_mg)); } - return SUCCESS; + return Status::OK(); } -std::tuple ShardWriter::ValidateRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign) { +Status ShardWriter::ValidateRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign, + std::shared_ptr> *count_ptr) { + RETURN_UNEXPECTED_IF_NULL(count_ptr); auto rawdata_iter = raw_data.begin(); schema_count_ = raw_data.size(); - std::tuple failed(FAILED, 0, 0); - if (schema_count_ == 0) { - MS_LOG(ERROR) << "Data size is zero"; - return failed; - } + CHECK_FAIL_RETURN_UNEXPECTED(schema_count_ > 0, "Data size is not positive."); // keep schema_id std::set schema_ids; @@ -474,52 +374,38 @@ std::tuple ShardWriter::ValidateRawData(std::mapGetSchemas().size() != schema_count_) { - MS_LOG(ERROR) << "Data size is not equal with the schema size"; - return failed; - } - + CHECK_FAIL_RETURN_UNEXPECTED(shard_header_->GetSchemas().size() == schema_count_, + "Data size is not equal with the schema size"); // Determine raw_data size == blob_data size - if (raw_data[0].size() != blob_data.size()) { - MS_LOG(ERROR) << "Raw data size is not equal blob data size"; - return failed; - } + CHECK_FAIL_RETURN_UNEXPECTED(raw_data[0].size() == blob_data.size(), "Raw data size is not equal blob data size"); // Determine whether the number of samples corresponding to each schema is the same for (rawdata_iter = raw_data.begin(); rawdata_iter != raw_data.end(); ++rawdata_iter) { - if (row_count_ != rawdata_iter->second.size()) { - MS_LOG(ERROR) << "Data size is not equal"; - return failed; - } + CHECK_FAIL_RETURN_UNEXPECTED(row_count_ == rawdata_iter->second.size(), "Data size is not equal"); (void)schema_ids.insert(rawdata_iter->first); } const std::vector> &schemas = shard_header_->GetSchemas(); - if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr &schema) { - return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); - })) { - // There is not enough data which is not matching the number of schema - MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; - return failed; - } - + // There is not enough data which is not matching the number of schema + CHECK_FAIL_RETURN_UNEXPECTED(!std::any_of(schemas.begin(), schemas.end(), + [schema_ids](const std::shared_ptr &schema) { + return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); + }), + "Input rawdata schema id do not match real schema id."); if (!sign) { - std::tuple success(SUCCESS, schema_count_, row_count_); - return success; + *count_ptr = std::make_shared>(schema_count_, row_count_); + return Status::OK(); } // check the data according the schema - if (CheckData(raw_data) != SUCCESS) { - MS_LOG(ERROR) << "Data validate check failed"; - return std::tuple(FAILED, schema_count_, row_count_); - } + RETURN_IF_NOT_OK(CheckData(raw_data)); // delete wrong data from raw data DeleteErrorData(raw_data, blob_data); // update raw count row_count_ = row_count_ - err_mg_.begin()->second.size(); - std::tuple success(SUCCESS, schema_count_, row_count_); - return success; + *count_ptr = std::make_shared>(schema_count_, row_count_); + return Status::OK(); } void ShardWriter::FillArray(int start, int end, std::map> &raw_data, @@ -544,22 +430,23 @@ void ShardWriter::FillArray(int start, int end, std::map> } } -int ShardWriter::LockWriter(bool parallel_writer) { +Status ShardWriter::LockWriter(bool parallel_writer, std::unique_ptr *fd_ptr) { if (!parallel_writer) { - return 0; + *fd_ptr = std::make_unique(0); + return Status::OK(); } #if defined(_WIN32) || defined(_WIN64) MS_LOG(DEBUG) << "Lock file done by python."; const int fd = 0; + #else const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); if (fd >= 0) { flock(fd, LOCK_EX); } else { - MS_LOG(ERROR) << "Shard writer failed when locking file"; close(fd); - return -1; + RETURN_STATUS_UNEXPECTED("Shard writer failed when locking file."); } #endif @@ -568,62 +455,50 @@ int ShardWriter::LockWriter(bool parallel_writer) { for (const auto &file : file_paths_) { auto realpath = Common::GetRealPath(file); if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << file; close(fd); - return -1; + RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + file); } - std::shared_ptr fs = std::make_shared(); fs->open(realpath.value(), std::ios::in | std::ios::out | std::ios::binary); if (fs->fail()) { - MS_LOG(ERROR) << "Invalid file, failed to open file: " << file; close(fd); - return -1; + RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file); } file_streams_.push_back(fs); } - - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Invalid data, failed to read pages from file."; + auto status = shard_header_->FileToPages(pages_file_); + if (status.IsError()) { close(fd); - return -1; + RETURN_STATUS_UNEXPECTED("Error in FileToPages."); } - return fd; + *fd_ptr = std::make_unique(fd); + return Status::OK(); } -MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { +Status ShardWriter::UnlockWriter(int fd, bool parallel_writer) { if (!parallel_writer) { - return SUCCESS; - } - - if (shard_header_->PagesToFile(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Write pages to file failed"; - return FAILED; + return Status::OK(); } - + RETURN_IF_NOT_OK(shard_header_->PagesToFile(pages_file_)); for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { file_streams_[i]->close(); } - #if defined(_WIN32) || defined(_WIN64) MS_LOG(DEBUG) << "Unlock file done by python."; #else flock(fd, LOCK_UN); close(fd); #endif - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, - std::vector> &blob_data, bool sign, int *schema_count, - int *row_count) { +Status ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, + std::vector> &blob_data, bool sign, int *schema_count, + int *row_count) { // check the free disk size - auto st_space = GetDiskSize(file_paths_[0], kFreeSize); - if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { - MS_LOG(ERROR) << "IO error / there is no free disk to be used"; - return FAILED; - } - + std::shared_ptr size_ptr; + RETURN_IF_NOT_OK(GetDiskSize(file_paths_[0], kFreeSize, &size_ptr)); + CHECK_FAIL_RETURN_UNEXPECTED(*size_ptr >= kMinFreeDiskSize, "IO error / there is no free disk to be used"); // compress blob if (shard_column_->CheckCompressBlob()) { for (auto &blob : blob_data) { @@ -642,21 +517,17 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map if (blob_data.size() > 0 && raw_data.size() == 0) { raw_data.insert(std::pair>(0, std::vector(blob_data.size(), kDummyId))); } - - auto v = ValidateRawData(raw_data, blob_data, sign); - if (std::get<0>(v) == FAILED) { - MS_LOG(ERROR) << "Validate raw data failed"; - return FAILED; - } - *schema_count = std::get<1>(v); - *row_count = std::get<2>(v); - return SUCCESS; + std::shared_ptr> count_ptr; + RETURN_IF_NOT_OK(ValidateRawData(raw_data, blob_data, sign, &count_ptr)); + *schema_count = (*count_ptr).first; + *row_count = (*count_ptr).second; + return Status::OK(); } -MSRStatus ShardWriter::MergeBlobData(const std::vector &blob_fields, - const std::map>> &row_bin_data, - std::shared_ptr> *output) { +Status ShardWriter::MergeBlobData(const std::vector &blob_fields, + const std::map>> &row_bin_data, + std::shared_ptr> *output) { if (blob_fields.empty()) { - return SUCCESS; + return Status::OK(); } if (blob_fields.size() == 1) { auto &blob = row_bin_data.at(blob_fields[0]); @@ -686,71 +557,44 @@ MSRStatus ShardWriter::MergeBlobData(const std::vector &blob_fields, idx += b->size(); } } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign, bool parallel_writer) { +Status ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign, bool parallel_writer) { // Lock Writer if loading data parallel - int fd = LockWriter(parallel_writer); - if (fd < 0) { - MS_LOG(ERROR) << "Lock writer failed"; - return FAILED; - } + std::unique_ptr fd_ptr; + RETURN_IF_NOT_OK(LockWriter(parallel_writer, &fd_ptr)); // Get the count of schemas and rows int schema_count = 0; int row_count = 0; // Serialize raw data - if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { - MS_LOG(ERROR) << "Check raw data failed"; - return FAILED; - } - + RETURN_IF_NOT_OK(WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count)); + CHECK_FAIL_RETURN_UNEXPECTED(row_count >= kInt0, "Raw data size is not positive."); if (row_count == kInt0) { - MS_LOG(INFO) << "Raw data size is 0."; - return SUCCESS; + return Status::OK(); } - std::vector> bin_raw_data(row_count * schema_count); - // Serialize raw data - if (SerializeRawData(raw_data, bin_raw_data, row_count) == FAILED) { - MS_LOG(ERROR) << "Serialize raw data failed"; - return FAILED; - } - + RETURN_IF_NOT_OK(SerializeRawData(raw_data, bin_raw_data, row_count)); // Set row size of raw data - if (SetRawDataSize(bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Set raw data size failed"; - return FAILED; - } - + RETURN_IF_NOT_OK(SetRawDataSize(bin_raw_data)); // Set row size of blob data - if (SetBlobDataSize(blob_data) == FAILED) { - MS_LOG(ERROR) << "Set blob data size failed"; - return FAILED; - } - + RETURN_IF_NOT_OK(SetBlobDataSize(blob_data)); // Write data to disk with multi threads - if (ParallelWriteData(blob_data, bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Parallel write data failed"; - return FAILED; - } + RETURN_IF_NOT_OK(ParallelWriteData(blob_data, bin_raw_data)); MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; - if (UnlockWriter(fd, parallel_writer) == FAILED) { - MS_LOG(ERROR) << "Unlock writer failed"; - return FAILED; - } + RETURN_IF_NOT_OK(UnlockWriter(*fd_ptr, parallel_writer)); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign, - bool parallel_writer) { +Status ShardWriter::WriteRawData(std::map> &raw_data, + std::map> &blob_data, bool sign, + bool parallel_writer) { std::map> raw_data_json; std::map> blob_data_json; @@ -780,35 +624,16 @@ MSRStatus ShardWriter::WriteRawData(std::map> std::vector> bin_blob_data(row_count * schema_count); // Serialize blob data - if (SerializeRawData(blob_data_json, bin_blob_data, row_count) == FAILED) { - MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; - return FAILED; - } + RETURN_IF_NOT_OK(SerializeRawData(blob_data_json, bin_blob_data, row_count)); return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); } -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - vector> &blob_data, bool sign, bool parallel_writer) { - std::map> raw_data_json; - (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), - [](const std::pair> &pair) { - auto &py_raw_data = pair.second; - std::vector json_raw_data; - (void)std::transform(py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), - [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); - return std::make_pair(pair.first, std::move(json_raw_data)); - }); - return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); -} - -MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, - const std::vector> &bin_raw_data) { +Status ShardWriter::ParallelWriteData(const std::vector> &blob_data, + const std::vector> &bin_raw_data) { auto shards = BreakIntoShards(); // define the number of thread int thread_num = static_cast(shard_count_); - if (thread_num < 0) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(thread_num > 0, "Invalid thread number."); if (thread_num > kMaxThreadCount) { thread_num = kMaxThreadCount; } @@ -835,16 +660,16 @@ MSRStatus ShardWriter::ParallelWriteData(const std::vector> current_thread += thread_num; } } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::WriteByShard(int shard_id, int start_row, int end_row, - const std::vector> &blob_data, - const std::vector> &bin_raw_data) { +Status ShardWriter::WriteByShard(int shard_id, int start_row, int end_row, + const std::vector> &blob_data, + const std::vector> &bin_raw_data) { MS_LOG(DEBUG) << "Shard: " << shard_id << ", start: " << start_row << ", end: " << end_row << ", schema size: " << schema_count_; if (start_row == end_row) { - return SUCCESS; + return Status::OK(); } vector> rows_in_group; std::shared_ptr last_raw_page = nullptr; @@ -852,38 +677,19 @@ MSRStatus ShardWriter::WriteByShard(int shard_id, int start_row, int end_row, SetLastRawPage(shard_id, last_raw_page); SetLastBlobPage(shard_id, last_blob_page); - if (CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page) == FAILED) { - MS_LOG(ERROR) << "Cut row group failed"; - return FAILED; - } - - if (AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { - MS_LOG(ERROR) << "Append bolb page failed"; - return FAILED; - } - - if (NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page) == FAILED) { - MS_LOG(ERROR) << "New blob page failed"; - return FAILED; - } - - if (ShiftRawPage(shard_id, rows_in_group, last_raw_page) == FAILED) { - MS_LOG(ERROR) << "Shit raw page failed"; - return FAILED; - } - - if (WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data) == FAILED) { - MS_LOG(ERROR) << "Write raw page failed"; - return FAILED; - } + RETURN_IF_NOT_OK(CutRowGroup(start_row, end_row, blob_data, rows_in_group, last_raw_page, last_blob_page)); + RETURN_IF_NOT_OK(AppendBlobPage(shard_id, blob_data, rows_in_group, last_blob_page)); + RETURN_IF_NOT_OK(NewBlobPage(shard_id, blob_data, rows_in_group, last_blob_page)); + RETURN_IF_NOT_OK(ShiftRawPage(shard_id, rows_in_group, last_raw_page)); + RETURN_IF_NOT_OK(WriteRawPage(shard_id, rows_in_group, last_raw_page, bin_raw_data)); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, - std::vector> &rows_in_group, - const std::shared_ptr &last_raw_page, - const std::shared_ptr &last_blob_page) { +Status ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector> &blob_data, + std::vector> &rows_in_group, + const std::shared_ptr &last_raw_page, + const std::shared_ptr &last_blob_page) { auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0; auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; @@ -891,12 +697,11 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector auto n_byte_raw = last_raw_page_size - last_raw_offset; int page_start_row = start_row; - if (start_row > end_row) { - return FAILED; - } - if (end_row > static_cast(blob_data_size_.size()) || end_row > static_cast(raw_data_size_.size())) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(start_row <= end_row, "Invalid start row."); + + CHECK_FAIL_RETURN_UNEXPECTED( + end_row <= static_cast(blob_data_size_.size()) && end_row <= static_cast(raw_data_size_.size()), + "Invalid end row."); for (int i = start_row; i < end_row; ++i) { // n_byte_blob(0) indicate appendBlobPage if (n_byte_blob == 0 || n_byte_blob + blob_data_size_[i] > page_size_ || @@ -913,23 +718,23 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector // Not forget last one rows_in_group.emplace_back(page_start_row, end_row); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page) { +Status ShardWriter::AppendBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page) { auto blob_row = rows_in_group[0]; - if (blob_row.first == blob_row.second) return SUCCESS; - + if (blob_row.first == blob_row.second) { + return Status::OK(); + } // Write disk auto page_id = last_blob_page->GetPageID(); auto bytes_page = last_blob_page->GetPageSize(); auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); @@ -940,12 +745,12 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vectorGetEndRowID() + blob_row.second - blob_row.first; last_blob_page->SetEndRowID(end_row); (void)shard_header_->SetPage(last_blob_page); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector> &blob_data, - const std::vector> &rows_in_group, - const std::shared_ptr &last_blob_page) { +Status ShardWriter::NewBlobPage(const int &shard_id, const std::vector> &blob_data, + const std::vector> &rows_in_group, + const std::shared_ptr &last_blob_page) { auto page_id = shard_header_->GetLastPageId(shard_id); auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0; @@ -956,9 +761,8 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vectorseekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } (void)FlushBlobChunk(file_streams_[shard_id], blob_data, blob_row); @@ -972,18 +776,20 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vectorAddPage(std::make_shared(page)); current_row = end_row; } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page) { +Status ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page) { auto blob_row = rows_in_group[0]; - if (blob_row.first == blob_row.second) return SUCCESS; + if (blob_row.first == blob_row.second) { + return Status::OK(); + } auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + last_raw_page_size <= page_size_) { - return SUCCESS; + return Status::OK(); } auto page_id = shard_header_->GetLastPageId(shard_id); auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; @@ -993,38 +799,32 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector buf(shift_size); // Read last row group from previous raw data page - if (shard_id < 0 || shard_id >= file_streams_.size()) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(shard_id >= 0 && shard_id < file_streams_.size(), "Invalid shard id"); auto &io_seekg = file_streams_[shard_id]->seekg( page_size_ * last_raw_page_id + header_size_ + last_row_group_id_offset, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { - MS_LOG(ERROR) << "File seekg failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } auto &io_read = file_streams_[shard_id]->read(reinterpret_cast(&buf[0]), buf.size()); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to read file."); } // Merge into new row group at new raw data page auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * (page_id + 1) + header_size_, std::ios::beg); if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&buf[0]), buf.size()); if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to write file."); } last_raw_page->DeleteLastGroupId(); (void)shard_header_->SetPage(last_raw_page); @@ -1039,44 +839,45 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, - std::shared_ptr &last_raw_page, - const std::vector> &bin_raw_data) { +Status ShardWriter::WriteRawPage(const int &shard_id, const std::vector> &rows_in_group, + std::shared_ptr &last_raw_page, + const std::vector> &bin_raw_data) { int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; for (uint32_t i = 0; i < rows_in_group.size(); ++i) { const auto &blob_row = rows_in_group[i]; - if (blob_row.first == blob_row.second) continue; + if (blob_row.first == blob_row.second) { + continue; + } auto raw_size = std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); if (!last_raw_page) { - EmptyRawPage(shard_id, last_raw_page); + RETURN_IF_NOT_OK(EmptyRawPage(shard_id, last_raw_page)); } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { - (void)shard_header_->SetPage(last_raw_page); - EmptyRawPage(shard_id, last_raw_page); - } - if (AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data) != SUCCESS) { - return FAILED; + RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page)); + RETURN_IF_NOT_OK(EmptyRawPage(shard_id, last_raw_page)); } + RETURN_IF_NOT_OK(AppendRawPage(shard_id, rows_in_group, i, last_row_group_id, last_raw_page, bin_raw_data)); } - (void)shard_header_->SetPage(last_raw_page); - return SUCCESS; + RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page)); + return Status::OK(); } -void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { +Status ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { auto row_group_ids = std::vector>(); auto page_id = shard_header_->GetLastPageId(shard_id); auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); - (void)shard_header_->AddPage(std::make_shared(page)); + RETURN_IF_NOT_OK(shard_header_->AddPage(std::make_shared(page))); SetLastRawPage(shard_id, last_raw_page); + return Status::OK(); } -MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, - const int &chunk_id, int &last_row_group_id, std::shared_ptr last_raw_page, - const std::vector> &bin_raw_data) { +Status ShardWriter::AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, + const int &chunk_id, int &last_row_group_id, std::shared_ptr last_raw_page, + const std::vector> &bin_raw_data) { std::vector> row_group_ids = last_raw_page->GetRowGroupIds(); auto last_raw_page_id = last_raw_page->GetPageID(); auto n_bytes = last_raw_page->GetPageSize(); @@ -1085,67 +886,62 @@ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vectorseekp(page_size_ * last_raw_page_id + header_size_ + n_bytes, std::ios::beg); if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekg file."); } - if (chunk_id > 0) row_group_ids.emplace_back(++last_row_group_id, n_bytes); + if (chunk_id > 0) { + row_group_ids.emplace_back(++last_row_group_id, n_bytes); + } n_bytes += std::accumulate(raw_data_size_.begin() + rows_in_group[chunk_id].first, raw_data_size_.begin() + rows_in_group[chunk_id].second, 0); - (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); + RETURN_IF_NOT_OK(FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data)); // Update previous raw data page last_raw_page->SetPageSize(n_bytes); last_raw_page->SetRowGroupIds(row_group_ids); - (void)shard_header_->SetPage(last_raw_page); + RETURN_IF_NOT_OK(shard_header_->SetPage(last_raw_page)); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::FlushBlobChunk(const std::shared_ptr &out, - const std::vector> &blob_data, - const std::pair &blob_row) { - if (blob_row.first > blob_row.second) { - return FAILED; - } - if (blob_row.second > static_cast(blob_data.size()) || blob_row.first < 0) { - return FAILED; - } +Status ShardWriter::FlushBlobChunk(const std::shared_ptr &out, + const std::vector> &blob_data, + const std::pair &blob_row) { + CHECK_FAIL_RETURN_UNEXPECTED( + blob_row.first <= blob_row.second && blob_row.second <= static_cast(blob_data.size()) && blob_row.first >= 0, + "Invalid blob row"); for (int j = blob_row.first; j < blob_row.second; ++j) { // Write the size of blob uint64_t line_len = blob_data[j].size(); auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; out->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to write file."); } // Write the data of blob auto line = blob_data[j]; auto &io_handle_data = out->write(reinterpret_cast(&line[0]), line_len); if (!io_handle_data.good() || io_handle_data.fail() || io_handle_data.bad()) { - MS_LOG(ERROR) << "File write failed"; out->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to write file."); } } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::FlushRawChunk(const std::shared_ptr &out, - const std::vector> &rows_in_group, const int &chunk_id, - const std::vector> &bin_raw_data) { +Status ShardWriter::FlushRawChunk(const std::shared_ptr &out, + const std::vector> &rows_in_group, const int &chunk_id, + const std::vector> &bin_raw_data) { for (int i = rows_in_group[chunk_id].first; i < rows_in_group[chunk_id].second; i++) { // Write the size of multi schemas for (uint32_t j = 0; j < schema_count_; ++j) { uint64_t line_len = bin_raw_data[i * schema_count_ + j].size(); auto &io_handle = out->write(reinterpret_cast(&line_len), kInt64Len); if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; out->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to write file."); } } // Write the data of multi schemas @@ -1153,13 +949,12 @@ MSRStatus ShardWriter::FlushRawChunk(const std::shared_ptr &out, auto line = bin_raw_data[i * schema_count_ + j]; auto &io_handle = out->write(reinterpret_cast(&line[0]), line.size()); if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; out->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to write file."); } } } - return SUCCESS; + return Status::OK(); } // Allocate data to shards evenly @@ -1187,62 +982,55 @@ std::vector> ShardWriter::BreakIntoShards() { return shards; } -MSRStatus ShardWriter::WriteShardHeader() { - if (shard_header_ == nullptr) { - MS_LOG(ERROR) << "Shard header is null"; - return FAILED; - } - +Status ShardWriter::WriteShardHeader() { + RETURN_UNEXPECTED_IF_NULL(shard_header_); int64_t compression_temp = compression_size_; uint64_t compression_size = compression_temp > 0 ? compression_temp : 0; shard_header_->SetCompressionSize(compression_size); auto shard_header = shard_header_->SerializeHeader(); // Write header data to multi files - if (shard_count_ > static_cast(file_streams_.size()) || shard_count_ > static_cast(shard_header.size())) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED( + shard_count_ <= static_cast(file_streams_.size()) && shard_count_ <= static_cast(shard_header.size()), + "Invalid shard count"); if (shard_count_ <= kMaxShardCount) { for (int shard_id = 0; shard_id < shard_count_; ++shard_id) { auto &io_seekp = file_streams_[shard_id]->seekp(0, std::ios::beg); if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { - MS_LOG(ERROR) << "File seekp failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to seekp file."); } std::vector bin_header(shard_header[shard_id].begin(), shard_header[shard_id].end()); uint64_t line_len = bin_header.size(); if (line_len + kInt64Len > header_size_) { - MS_LOG(ERROR) << "Shard header is too big"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("shard header is too big."); } - auto &io_handle = file_streams_[shard_id]->write(reinterpret_cast(&line_len), kInt64Len); if (!io_handle.good() || io_handle.fail() || io_handle.bad()) { - MS_LOG(ERROR) << "File write failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to write file."); } auto &io_handle_header = file_streams_[shard_id]->write(reinterpret_cast(&bin_header[0]), line_len); if (!io_handle_header.good() || io_handle_header.fail() || io_handle_header.bad()) { - MS_LOG(ERROR) << "File write failed"; file_streams_[shard_id]->close(); - return FAILED; + RETURN_STATUS_UNEXPECTED("Failed to write file."); } file_streams_[shard_id]->close(); } } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardWriter::SerializeRawData(std::map> &raw_data, - std::vector> &bin_data, uint32_t row_count) { +Status ShardWriter::SerializeRawData(std::map> &raw_data, + std::vector> &bin_data, uint32_t row_count) { // define the number of thread uint32_t thread_num = std::thread::hardware_concurrency(); - if (thread_num == 0) thread_num = kThreadNumber; + if (thread_num == 0) { + thread_num = kThreadNumber; + } // Set the number of samples processed by each thread int group_num = ceil(row_count * 1.0 / thread_num); std::vector thread_set(thread_num); @@ -1262,66 +1050,54 @@ MSRStatus ShardWriter::SerializeRawData(std::map> &r // Set obstacles to prevent the main thread from running thread_set[x].join(); } - return flag_ == true ? FAILED : SUCCESS; + CHECK_FAIL_RETURN_SYNTAX_ERROR(flag_ != true, "Error in FailArray"); + return Status::OK(); } -MSRStatus ShardWriter::SetRawDataSize(const std::vector> &bin_raw_data) { +Status ShardWriter::SetRawDataSize(const std::vector> &bin_raw_data) { raw_data_size_ = std::vector(row_count_, 0); for (uint32_t i = 0; i < row_count_; ++i) { raw_data_size_[i] = std::accumulate( bin_raw_data.begin() + (i * schema_count_), bin_raw_data.begin() + (i * schema_count_) + schema_count_, 0, [](uint64_t accumulator, const std::vector &row) { return accumulator + kInt64Len + row.size(); }); } - if (*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) > page_size_) { - MS_LOG(ERROR) << "Page size is too small to save a row!"; - return FAILED; - } - return SUCCESS; + CHECK_FAIL_RETURN_SYNTAX_ERROR(*std::max_element(raw_data_size_.begin(), raw_data_size_.end()) <= page_size_, + "Page size is too small to save a row!"); + return Status::OK(); } -MSRStatus ShardWriter::SetBlobDataSize(const std::vector> &blob_data) { +Status ShardWriter::SetBlobDataSize(const std::vector> &blob_data) { blob_data_size_ = std::vector(row_count_); (void)std::transform(blob_data.begin(), blob_data.end(), blob_data_size_.begin(), [](const std::vector &row) { return kInt64Len + row.size(); }); - if (*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) > page_size_) { - MS_LOG(ERROR) << "Page size is too small to save a row!"; - return FAILED; - } - return SUCCESS; + CHECK_FAIL_RETURN_SYNTAX_ERROR(*std::max_element(blob_data_size_.begin(), blob_data_size_.end()) <= page_size_, + "Page size is too small to save a row!"); + return Status::OK(); } -void ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { +Status ShardWriter::SetLastRawPage(const int &shard_id, std::shared_ptr &last_raw_page) { // Get last raw page auto last_raw_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeRaw); - if (last_raw_page_id >= 0) { - auto page = shard_header_->GetPage(shard_id, last_raw_page_id); - last_raw_page = page.first; - } + CHECK_FAIL_RETURN_SYNTAX_ERROR(last_raw_page_id >= 0, "Invalid last_raw_page_id."); + RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_raw_page_id, &last_raw_page)); + return Status::OK(); } -void ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page) { +Status ShardWriter::SetLastBlobPage(const int &shard_id, std::shared_ptr &last_blob_page) { // Get last blob page auto last_blob_page_id = shard_header_->GetLastPageIdByType(shard_id, kPageTypeBlob); - if (last_blob_page_id >= 0) { - auto page = shard_header_->GetPage(shard_id, last_blob_page_id); - last_blob_page = page.first; - } + CHECK_FAIL_RETURN_SYNTAX_ERROR(last_blob_page_id >= 0, "Invalid last_blob_page_id."); + RETURN_IF_NOT_OK(shard_header_->GetPage(shard_id, last_blob_page_id, &last_blob_page)); + return Status::OK(); } -MSRStatus ShardWriter::Initialize(const std::unique_ptr *writer_ptr, - const std::vector &file_names) { - if (writer_ptr == nullptr) { - MS_LOG(ERROR) << "ShardWriter pointer is NULL."; - return FAILED; - } - auto res = (*writer_ptr)->Open(file_names, false); - if (SUCCESS != res) { - MS_LOG(ERROR) << "Failed to open mindrecord files to writer."; - return FAILED; - } - (*writer_ptr)->SetHeaderSize(kDefaultHeaderSize); - (*writer_ptr)->SetPageSize(kDefaultPageSize); - return SUCCESS; +Status ShardWriter::Initialize(const std::unique_ptr *writer_ptr, + const std::vector &file_names) { + RETURN_UNEXPECTED_IF_NULL(writer_ptr); + RETURN_IF_NOT_OK((*writer_ptr)->Open(file_names, false)); + RETURN_IF_NOT_OK((*writer_ptr)->SetHeaderSize(kDefaultHeaderSize)); + RETURN_IF_NOT_OK((*writer_ptr)->SetPageSize(kDefaultPageSize)); + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc index 7f2a479655..462023d09b 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_category.cc @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem num_categories_(num_categories), replacement_(replacement) {} -MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; } +Status ShardCategory::Execute(ShardTaskList &tasks) { return Status::OK(); } int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { if (dataset_size == 0) return dataset_size; diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc index 25d0463dd8..245c607dd6 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc @@ -72,36 +72,36 @@ void ShardColumn::Init(const json &schema_json, bool compress_integer) { num_blob_column_ = blob_column_.size(); } -std::pair ShardColumn::GetColumnTypeByName(const std::string &column_name, - ColumnDataType *column_data_type, - uint64_t *column_data_type_size, - std::vector *column_shape) { +Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type, + uint64_t *column_data_type_size, std::vector *column_shape, + ColumnCategory *column_category) { + RETURN_UNEXPECTED_IF_NULL(column_data_type); + RETURN_UNEXPECTED_IF_NULL(column_data_type_size); + RETURN_UNEXPECTED_IF_NULL(column_shape); + RETURN_UNEXPECTED_IF_NULL(column_category); // Skip if column not found - auto column_category = CheckColumnName(column_name); - if (column_category == ColumnNotFound) { - return {FAILED, ColumnNotFound}; - } + *column_category = CheckColumnName(column_name); + CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid column category."); // Get data type and size auto column_id = column_name_id_[column_name]; *column_data_type = column_data_type_[column_id]; *column_data_type_size = ColumnDataTypeSize[*column_data_type]; *column_shape = column_shape_[column_id]; - - return {SUCCESS, column_category}; + return Status::OK(); } -MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, - const json &columns_json, const unsigned char **data, - std::unique_ptr *data_ptr, uint64_t *const n_bytes, - ColumnDataType *column_data_type, uint64_t *column_data_type_size, - std::vector *column_shape) { +Status ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector &columns_blob, + const json &columns_json, const unsigned char **data, + std::unique_ptr *data_ptr, uint64_t *const n_bytes, + ColumnDataType *column_data_type, uint64_t *column_data_type_size, + std::vector *column_shape) { + RETURN_UNEXPECTED_IF_NULL(column_data_type); + RETURN_UNEXPECTED_IF_NULL(column_data_type_size); + RETURN_UNEXPECTED_IF_NULL(column_shape); // Skip if column not found auto column_category = CheckColumnName(column_name); - if (column_category == ColumnNotFound) { - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid column category."); // Get data type and size auto column_id = column_name_id_[column_name]; *column_data_type = column_data_type_[column_id]; @@ -110,37 +110,31 @@ MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, cons // Retrieve value from json if (column_category == ColumnInRaw) { - if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { - MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; - return FAILED; - } + RETURN_IF_NOT_OK(GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes)); *data = reinterpret_cast(data_ptr->get()); - return SUCCESS; + return Status::OK(); } // Retrieve value from blob - if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { - MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; - return FAILED; - } + RETURN_IF_NOT_OK(GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes)); if (*data == nullptr) { *data = reinterpret_cast(data_ptr->get()); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, - std::unique_ptr *data_ptr, uint64_t *n_bytes) { +Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, + std::unique_ptr *data_ptr, uint64_t *n_bytes) { + RETURN_UNEXPECTED_IF_NULL(n_bytes); + RETURN_UNEXPECTED_IF_NULL(data_ptr); auto column_id = column_name_id_[column_name]; auto column_data_type = column_data_type_[column_id]; // Initialize num bytes *n_bytes = ColumnDataTypeSize[column_data_type]; auto json_column_value = columns_json[column_name]; - if (!json_column_value.is_string() && !json_column_value.is_number()) { - MS_LOG(ERROR) << "Conversion failed (" << json_column_value << ")."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(json_column_value.is_string() || json_column_value.is_number(), + "Conversion to string or number failed (" + json_column_value.dump() + ")."); switch (column_data_type) { case ColumnFloat32: { return GetFloat(data_ptr, json_column_value, false); @@ -171,12 +165,13 @@ MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const j break; } } - return SUCCESS; + return Status::OK(); } template -MSRStatus ShardColumn::GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, - bool use_double) { +Status ShardColumn::GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, + bool use_double) { + RETURN_UNEXPECTED_IF_NULL(data_ptr); std::unique_ptr array_data = std::make_unique(1); if (json_column_value.is_number()) { array_data[0] = json_column_value; @@ -189,8 +184,7 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr *data_ptr, cons array_data[0] = json_column_value.get(); } } catch (json::exception &e) { - MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; - return FAILED; + RETURN_STATUS_UNEXPECTED("Conversion to float failed (" + json_column_value.dump() + ")."); } } @@ -199,54 +193,43 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr *data_ptr, cons for (uint32_t i = 0; i < sizeof(T); i++) { (*data_ptr)[i] = *(data + i); } - - return SUCCESS; + return Status::OK(); } template -MSRStatus ShardColumn::GetInt(std::unique_ptr *data_ptr, const json &json_column_value) { +Status ShardColumn::GetInt(std::unique_ptr *data_ptr, const json &json_column_value) { + RETURN_UNEXPECTED_IF_NULL(data_ptr); std::unique_ptr array_data = std::make_unique(1); int64_t temp_value; bool less_than_zero = false; if (json_column_value.is_number_integer()) { const json json_zero = 0; - if (json_column_value < json_zero) less_than_zero = true; + if (json_column_value < json_zero) { + less_than_zero = true; + } temp_value = json_column_value; } else if (json_column_value.is_string()) { std::string string_value = json_column_value; - - if (!string_value.empty() && string_value[0] == '-') { - try { + try { + if (!string_value.empty() && string_value[0] == '-') { temp_value = std::stoll(string_value); less_than_zero = true; - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; - return FAILED; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "Conversion to int failed, out of range."; - return FAILED; - } - } else { - try { + } else { temp_value = static_cast(std::stoull(string_value)); - } catch (std::invalid_argument &e) { - MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; - return FAILED; - } catch (std::out_of_range &e) { - MS_LOG(ERROR) << "Conversion to int failed, out of range."; - return FAILED; } + } catch (std::invalid_argument &e) { + RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what())); + } catch (std::out_of_range &e) { + RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what())); } } else { - MS_LOG(ERROR) << "Conversion to int failed."; - return FAILED; + RETURN_STATUS_UNEXPECTED("Conversion to int failed."); } if ((less_than_zero && temp_value < static_cast(std::numeric_limits::min())) || (!less_than_zero && static_cast(temp_value) > static_cast(std::numeric_limits::max()))) { - MS_LOG(ERROR) << "Conversion to int failed. Out of range"; - return FAILED; + RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); } array_data[0] = static_cast(temp_value); @@ -255,33 +238,26 @@ MSRStatus ShardColumn::GetInt(std::unique_ptr *data_ptr, const for (uint32_t i = 0; i < sizeof(T); i++) { (*data_ptr)[i] = *(data + i); } - - return SUCCESS; + return Status::OK(); } -MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, - const unsigned char **data, std::unique_ptr *data_ptr, - uint64_t *const n_bytes) { +Status ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, + const unsigned char **data, std::unique_ptr *data_ptr, + uint64_t *const n_bytes) { + RETURN_UNEXPECTED_IF_NULL(data); uint64_t offset_address = 0; auto column_id = column_name_id_[column_name]; - if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { - return FAILED; - } - + RETURN_IF_NOT_OK(GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address)); auto column_data_type = column_data_type_[column_id]; if (has_compress_blob_ && column_data_type == ColumnInt32) { - if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { - return FAILED; - } + RETURN_IF_NOT_OK(UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address)); } else if (has_compress_blob_ && column_data_type == ColumnInt64) { - if (UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { - return FAILED; - } + RETURN_IF_NOT_OK(UncompressInt(column_id, data_ptr, columns_blob, n_bytes, offset_address)); } else { *data = reinterpret_cast(&(columns_blob[offset_address])); } - return SUCCESS; + return Status::OK(); } ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { @@ -296,7 +272,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { std::vector ShardColumn::CompressBlob(const std::vector &blob, int64_t *compression_size) { // Skip if no compress columns *compression_size = 0; - if (!CheckCompressBlob()) return blob; + if (!CheckCompressBlob()) { + return blob; + } std::vector dst_blob; uint64_t i_src = 0; @@ -380,12 +358,14 @@ vector ShardColumn::CompressInt(const vector &src_bytes, const return dst_bytes; } -MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, - uint64_t *num_bytes, uint64_t *shift_idx) { +Status ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector &columns_blob, + uint64_t *num_bytes, uint64_t *shift_idx) { + RETURN_UNEXPECTED_IF_NULL(num_bytes); + RETURN_UNEXPECTED_IF_NULL(shift_idx); if (num_blob_column_ == 1) { *num_bytes = columns_blob.size(); *shift_idx = 0; - return SUCCESS; + return Status::OK(); } auto blob_id = blob_column_id_[column_name_[column_id]]; @@ -396,13 +376,14 @@ MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const (*shift_idx) += kInt64Len; - return SUCCESS; + return Status::OK(); } template -MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, - const std::vector &columns_blob, uint64_t *num_bytes, - uint64_t shift_idx) { +Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr *const data_ptr, + const std::vector &columns_blob, uint64_t *num_bytes, uint64_t shift_idx) { + RETURN_UNEXPECTED_IF_NULL(data_ptr); + RETURN_UNEXPECTED_IF_NULL(num_bytes); auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); *num_bytes = sizeof(T) * num_elements; @@ -421,19 +402,12 @@ MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr< auto data = reinterpret_cast(array_data.get()); *data_ptr = std::make_unique(*num_bytes); - // field is none. for example: numpy is null if (*num_bytes == 0) { - return SUCCESS; + return Status::OK(); } - - int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes); - if (ret_code != 0) { - MS_LOG(ERROR) << "Failed to copy data!"; - return FAILED; - } - - return SUCCESS; + CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data!"); + return Status::OK(); } uint64_t ShardColumn::BytesBigToUInt64(const std::vector &bytes_array, const uint64_t &pos, diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc index 93c3c76ecf..7a86e2f402 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc @@ -55,15 +55,11 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_ return 0; } -MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { +Status ShardDistributedSample::PreExecute(ShardTaskList &tasks) { auto total_no = tasks.Size(); if (no_of_padded_samples_ > 0 && first_epoch_) { - if (total_no % denominator_ != 0) { - MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. " - << "task size: " << total_no << ", number padded: " << no_of_padded_samples_ - << ", denominator: " << denominator_; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(total_no % denominator_ == 0, + "Dataset size plus number of padded samples is not divisible by number of shards."); } if (first_epoch_) { first_epoch_ = false; @@ -74,11 +70,9 @@ MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { if (shuffle_ == true) { shuffle_op_->SetShardSampleCount(GetShardSampleCount()); shuffle_op_->UpdateShuffleMode(GetShuffleMode()); - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } + RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); } - return SUCCESS; + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc index 737b6e93c2..67dc336bfc 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc @@ -38,104 +38,74 @@ ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), co index_ = std::make_shared(); } -MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool load_dataset) { +Status ShardHeader::InitializeHeader(const std::vector &headers, bool load_dataset) { shard_count_ = headers.size(); int shard_index = 0; bool first = true; for (const auto &header : headers) { if (first) { first = false; - if (ParseSchema(header["schema"]) != SUCCESS) { - return FAILED; - } - if (ParseIndexFields(header["index_fields"]) != SUCCESS) { - return FAILED; - } - if (ParseStatistics(header["statistics"]) != SUCCESS) { - return FAILED; - } + RETURN_IF_NOT_OK(ParseSchema(header["schema"])); + RETURN_IF_NOT_OK(ParseIndexFields(header["index_fields"])); + RETURN_IF_NOT_OK(ParseStatistics(header["statistics"])); ParseShardAddress(header["shard_addresses"]); header_size_ = header["header_size"].get(); page_size_ = header["page_size"].get(); compression_size_ = header.contains("compression_size") ? header["compression_size"].get() : 0; } - if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { - return FAILED; - } + RETURN_IF_NOT_OK(ParsePage(header["page"], shard_index, load_dataset)); shard_index++; } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardHeader::CheckFileStatus(const std::string &path) { +Status ShardHeader::CheckFileStatus(const std::string &path) { auto realpath = Common::GetRealPath(path); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << path; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path: " + path); std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary); - if (!fin) { - MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path; - return FAILED; - } - if (fin.fail()) { - MS_LOG(ERROR) << "Failed to open file. path: " << path; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file. path: " + path); // fetch file size auto &io_seekg = fin.seekg(0, std::ios::end); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { fin.close(); - MS_LOG(ERROR) << "File seekg failed. path: " << path; - return FAILED; + RETURN_STATUS_UNEXPECTED("File seekg failed. path: " + path); } size_t file_size = fin.tellg(); if (file_size < kMinFileSize) { fin.close(); - MS_LOG(ERROR) << "Invalid file. path: " << path; - return FAILED; + RETURN_STATUS_UNEXPECTED("Invalid file. path: " + path); } fin.close(); - return SUCCESS; + return Status::OK(); } -std::pair ShardHeader::ValidateHeader(const std::string &path) { - if (CheckFileStatus(path) != SUCCESS) { - return {FAILED, {}}; - } - +Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr *header_ptr) { + RETURN_UNEXPECTED_IF_NULL(header_ptr); + RETURN_IF_NOT_OK(CheckFileStatus(path)); // read header size json json_header; std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); - if (!fin.is_open()) { - MS_LOG(ERROR) << "File seekg failed. path: " << path; - return {FAILED, json_header}; - } + CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file. path: " + path); uint64_t header_size = 0; auto &io_read = fin.read(reinterpret_cast(&header_size), kInt64Len); if (!io_read.good() || io_read.fail() || io_read.bad()) { - MS_LOG(ERROR) << "File read failed"; fin.close(); - return {FAILED, json_header}; + RETURN_STATUS_UNEXPECTED("File read failed"); } if (header_size > kMaxHeaderSize) { fin.close(); - MS_LOG(ERROR) << "Invalid file content. path: " << path; - return {FAILED, json_header}; + RETURN_STATUS_UNEXPECTED("Invalid file content. path: " + path); } // read header content std::vector header_content(header_size); auto &io_read_content = fin.read(reinterpret_cast(&header_content[0]), header_size); if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { - MS_LOG(ERROR) << "File read failed. path: " << path; fin.close(); - return {FAILED, json_header}; + RETURN_STATUS_UNEXPECTED("File read failed. path: " + path); } fin.close(); @@ -144,34 +114,35 @@ std::pair ShardHeader::ValidateHeader(const std::string &path) try { json_header = json::parse(raw_header_content); } catch (json::parse_error &e) { - MS_LOG(ERROR) << "Json parse error: " << e.what(); - return {FAILED, json_header}; + RETURN_STATUS_UNEXPECTED("Json parse error: " + std::string(e.what())); } - return {SUCCESS, json_header}; + *header_ptr = std::make_shared(json_header); + return Status::OK(); } -std::pair ShardHeader::BuildSingleHeader(const std::string &file_path) { - auto ret = ValidateHeader(file_path); - if (SUCCESS != ret.first) { - return {FAILED, json()}; - } - json raw_header = ret.second; +Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr *header_ptr) { + RETURN_UNEXPECTED_IF_NULL(header_ptr); + std::shared_ptr raw_header; + RETURN_IF_NOT_OK(ValidateHeader(file_path, &raw_header)); uint64_t compression_size = - raw_header.contains("compression_size") ? raw_header["compression_size"].get() : 0; - json header = {{"shard_addresses", raw_header["shard_addresses"]}, - {"header_size", raw_header["header_size"]}, - {"page_size", raw_header["page_size"]}, + raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get() : 0; + json header = {{"shard_addresses", (*raw_header)["shard_addresses"]}, + {"header_size", (*raw_header)["header_size"]}, + {"page_size", (*raw_header)["page_size"]}, {"compression_size", compression_size}, - {"index_fields", raw_header["index_fields"]}, - {"blob_fields", raw_header["schema"][0]["blob_fields"]}, - {"schema", raw_header["schema"][0]["schema"]}, - {"version", raw_header["version"]}}; - return {SUCCESS, header}; + {"index_fields", (*raw_header)["index_fields"]}, + {"blob_fields", (*raw_header)["schema"][0]["blob_fields"]}, + {"schema", (*raw_header)["schema"][0]["schema"]}, + {"version", (*raw_header)["version"]}}; + *header_ptr = std::make_shared(header); + return Status::OK(); } -MSRStatus ShardHeader::BuildDataset(const std::vector &file_paths, bool load_dataset) { +Status ShardHeader::BuildDataset(const std::vector &file_paths, bool load_dataset) { uint32_t thread_num = std::thread::hardware_concurrency(); - if (thread_num == 0) thread_num = kThreadNumber; + if (thread_num == 0) { + thread_num = kThreadNumber; + } uint32_t work_thread_num = 0; uint32_t shard_count = file_paths.size(); int group_num = ceil(shard_count * 1.0 / thread_num); @@ -194,12 +165,10 @@ MSRStatus ShardHeader::BuildDataset(const std::vector &file_paths, } if (thread_status) { thread_status = false; - return FAILED; + RETURN_STATUS_UNEXPECTED("Error occurred in GetHeadersOneTask thread."); } - if (SUCCESS != InitializeHeader(headers, load_dataset)) { - return FAILED; - } - return SUCCESS; + RETURN_IF_NOT_OK(InitializeHeader(headers, load_dataset)); + return Status::OK(); } void ShardHeader::GetHeadersOneTask(int start, int end, std::vector &headers, @@ -208,48 +177,39 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector &heade return; } for (int x = start; x < end; ++x) { - auto ret = ValidateHeader(realAddresses[x]); - if (SUCCESS != ret.first) { + std::shared_ptr header; + if (ValidateHeader(realAddresses[x], &header).IsError()) { thread_status = true; return; } - json header; - header = ret.second; - header["shard_addresses"] = realAddresses; - if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { - MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() + (*header)["shard_addresses"] = realAddresses; + if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) == + kSupportedVersion.end()) { + MS_LOG(ERROR) << "Version wrong, file version is: " << (*header)["version"].dump() << ", lib version is: " << kVersion; thread_status = true; return; } - headers[x] = header; + headers[x] = *header; } } -MSRStatus ShardHeader::InitByFiles(const std::vector &file_paths) { +Status ShardHeader::InitByFiles(const std::vector &file_paths) { std::vector file_names(file_paths.size()); std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { - if (GetFileName(fp).first == SUCCESS) { - return GetFileName(fp).second; - } + std::shared_ptr fn; + return GetFileName(fp, &fn).IsOk() ? *fn : ""; }); shard_addresses_ = std::move(file_names); shard_count_ = file_paths.size(); - if (shard_count_ == 0) { - return FAILED; - } - if (shard_count_ <= kMaxShardCount) { - pages_.resize(shard_count_); - } else { - return FAILED; - } - return SUCCESS; + CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount), + "shard count is invalid. shard count: " + std::to_string(shard_count_)); + pages_.resize(shard_count_); + return Status::OK(); } -void ShardHeader::ParseHeader(const json &header) {} - -MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { +Status ShardHeader::ParseIndexFields(const json &index_fields) { std::vector> parsed_index_fields; for (auto &index_field : index_fields) { auto schema_id = index_field["schema_id"].get(); @@ -257,18 +217,15 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { std::pair parsed_index_field(schema_id, field_name); parsed_index_fields.push_back(parsed_index_field); } - if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) { - return FAILED; - } - return SUCCESS; + RETURN_IF_NOT_OK(AddIndexFields(parsed_index_fields)); + return Status::OK(); } -MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { +Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { // set shard_index when load_dataset is false - if (shard_count_ > kMaxFileCount) { - MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED( + shard_count_ <= kMaxFileCount, + "The number of mindrecord files is greater than max value: " + std::to_string(kMaxFileCount)); if (pages_.empty() && shard_count_ <= kMaxFileCount) { pages_.resize(shard_count_); } @@ -295,44 +252,37 @@ MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_d pages_[shard_index].push_back(std::move(parsed_page)); } } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardHeader::ParseStatistics(const json &statistics) { +Status ShardHeader::ParseStatistics(const json &statistics) { for (auto &statistic : statistics) { - if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) { - MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump(); - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED( + statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(), + "Deserialize statistics failed, statistic: " + statistics.dump()); std::string statistic_description = statistic["desc"].get(); json statistic_body = statistic["statistics"]; std::shared_ptr parsed_statistic = Statistics::Build(statistic_description, statistic_body); - if (!parsed_statistic) { - return FAILED; - } + RETURN_UNEXPECTED_IF_NULL(parsed_statistic); AddStatistic(parsed_statistic); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardHeader::ParseSchema(const json &schemas) { +Status ShardHeader::ParseSchema(const json &schemas) { for (auto &schema : schemas) { // change how we get schemaBody once design is finalized - if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() || - schema.find("schema") == schema.end()) { - MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump(); - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() && + schema.find("schema") != schema.end(), + "Deserialize schema failed. schema: " + schema.dump()); std::string schema_description = schema["desc"].get(); std::vector blob_fields = schema["blob_fields"].get>(); json schema_body = schema["schema"]; std::shared_ptr parsed_schema = Schema::Build(schema_description, schema_body); - if (!parsed_schema) { - return FAILED; - } + RETURN_UNEXPECTED_IF_NULL(parsed_schema); AddSchema(parsed_schema); } - return SUCCESS; + return Status::OK(); } void ShardHeader::ParseShardAddress(const json &address) { @@ -340,7 +290,7 @@ void ShardHeader::ParseShardAddress(const json &address) { } std::vector ShardHeader::SerializeHeader() { - std::vector header; + std::vector header; auto index = SerializeIndexFields(); auto stats = SerializeStatistics(); auto schema = SerializeSchema(); @@ -406,45 +356,42 @@ std::string ShardHeader::SerializeSchema() { std::string ShardHeader::SerializeShardAddress() { json j; - (void)std::transform(shard_addresses_.begin(), shard_addresses_.end(), std::back_inserter(j), - [](const std::string &addr) { return GetFileName(addr).second; }); + std::shared_ptr fn_ptr; + for (const auto &addr : shard_addresses_) { + (void)GetFileName(addr, &fn_ptr); + j.emplace_back(*fn_ptr); + } return j.dump(); } -std::pair, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) { +Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr *page_ptr) { + RETURN_UNEXPECTED_IF_NULL(page_ptr); if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { - return std::make_pair(pages_[shard_id][page_id], SUCCESS); - } else { - return std::make_pair(nullptr, FAILED); + *page_ptr = pages_[shard_id][page_id]; + return Status::OK(); } + page_ptr = nullptr; + RETURN_STATUS_UNEXPECTED("Failed to Get Page."); } -MSRStatus ShardHeader::SetPage(const std::shared_ptr &new_page) { - if (new_page == nullptr) { - return FAILED; - } +Status ShardHeader::SetPage(const std::shared_ptr &new_page) { int shard_id = new_page->GetShardID(); int page_id = new_page->GetPageID(); if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { pages_[shard_id][page_id] = new_page; - return SUCCESS; - } else { - return FAILED; + return Status::OK(); } + RETURN_STATUS_UNEXPECTED("Failed to Set Page."); } -MSRStatus ShardHeader::AddPage(const std::shared_ptr &new_page) { - if (new_page == nullptr) { - return FAILED; - } +Status ShardHeader::AddPage(const std::shared_ptr &new_page) { int shard_id = new_page->GetShardID(); int page_id = new_page->GetPageID(); if (shard_id < static_cast(pages_.size()) && page_id == static_cast(pages_[shard_id].size())) { pages_[shard_id].push_back(new_page); - return SUCCESS; - } else { - return FAILED; + return Status::OK(); } + RETURN_STATUS_UNEXPECTED("Failed to Add Page."); } int64_t ShardHeader::GetLastPageId(const int &shard_id) { @@ -468,20 +415,18 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag return last_page_id; } -const std::pair> ShardHeader::GetPageByGroupId(const int &group_id, - const int &shard_id) { - if (shard_id >= static_cast(pages_.size())) { - MS_LOG(ERROR) << "Shard id is more than sum of shards."; - return {FAILED, nullptr}; - } +Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr *page_ptr) { + RETURN_UNEXPECTED_IF_NULL(page_ptr); + CHECK_FAIL_RETURN_UNEXPECTED(shard_id < static_cast(pages_.size()), "Shard id is more than sum of shards."); for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { auto page = pages_[shard_id][i - 1]; if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { - return {SUCCESS, page}; + *page_ptr = std::make_shared(*page); + return Status::OK(); } } - MS_LOG(ERROR) << "Could not get page by group id " << group_id; - return {FAILED, nullptr}; + page_ptr = nullptr; + RETURN_STATUS_UNEXPECTED("Failed to get page by group id: " + group_id); } int ShardHeader::AddSchema(std::shared_ptr schema) { @@ -524,151 +469,88 @@ std::shared_ptr ShardHeader::InitIndexPtr() { return index; } -MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) { +Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) { // check field name is or is not valid - if (schema.find(field) == schema.end()) { - MS_LOG(ERROR) << "Schema do not contain the field: " << field << "."; - return FAILED; - } - - if (schema[field]["type"] == "bytes") { - MS_LOG(ERROR) << field << " is bytes type, can not be schema index field."; - return FAILED; - } - - if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) { - MS_LOG(ERROR) << field << " array can not be schema index field."; - return FAILED; - } - return SUCCESS; + CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), "Filed can not found in schema."); + CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes", "bytes can not be as index field."); + CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(), + "array can not be as index field."); + return Status::OK(); } -MSRStatus ShardHeader::AddIndexFields(const std::vector &fields) { +Status ShardHeader::AddIndexFields(const std::vector &fields) { + if (fields.empty()) { + return Status::OK(); + } + CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Schema is empty."); // create index Object std::shared_ptr index = InitIndexPtr(); - - if (fields.size() == kInt0) { - MS_LOG(ERROR) << "There are no index fields"; - return FAILED; - } - - if (GetSchemas().empty()) { - MS_LOG(ERROR) << "No schema is set"; - return FAILED; - } - for (const auto &schemaPtr : schema_) { - auto result = GetSchemaByID(schemaPtr->GetSchemaID()); - if (result.second != SUCCESS) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - - if (result.first == nullptr) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - - json schema = result.first->GetSchema().at("schema"); - + std::shared_ptr schema_ptr; + RETURN_IF_NOT_OK(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr)); + json schema = schema_ptr->GetSchema().at("schema"); // checkout and add fields for each schema std::set field_set; for (const auto &item : index->GetFields()) { field_set.insert(item.second); } for (const auto &field : fields) { - if (field_set.find(field) != field_set.end()) { - MS_LOG(ERROR) << "Add same index field twice"; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice."); // check field name is or is not valid - if (CheckIndexField(field, schema) == FAILED) { - return FAILED; - } + RETURN_IF_NOT_OK(CheckIndexField(field, schema)); field_set.insert(field); - // add field into index index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); } } - index_ = index; - return SUCCESS; + return Status::OK(); } -MSRStatus ShardHeader::GetAllSchemaID(std::set &bucket_count) { +Status ShardHeader::GetAllSchemaID(std::set &bucket_count) { // get all schema id for (const auto &schema : schema_) { auto bucket_it = bucket_count.find(schema->GetSchemaID()); - if (bucket_it != bucket_count.end()) { - MS_LOG(ERROR) << "Schema duplication"; - return FAILED; - } else { - bucket_count.insert(schema->GetSchemaID()); - } + CHECK_FAIL_RETURN_UNEXPECTED(bucket_it == bucket_count.end(), "Schema duplication."); + bucket_count.insert(schema->GetSchemaID()); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardHeader::AddIndexFields(std::vector> fields) { +Status ShardHeader::AddIndexFields(std::vector> fields) { + if (fields.empty()) { + return Status::OK(); + } // create index Object std::shared_ptr index = InitIndexPtr(); - - if (fields.size() == kInt0) { - MS_LOG(ERROR) << "There are no index fields"; - return FAILED; - } - // get all schema id std::set bucket_count; - if (GetAllSchemaID(bucket_count) != SUCCESS) { - return FAILED; - } - + RETURN_IF_NOT_OK(GetAllSchemaID(bucket_count)); // check and add fields for each schema std::set> field_set; for (const auto &item : index->GetFields()) { field_set.insert(item); } for (const auto &field : fields) { - if (field_set.find(field) != field_set.end()) { - MS_LOG(ERROR) << "Add same index field twice"; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice."); uint64_t schema_id = field.first; std::string field_name = field.second; // check schemaId is or is not valid - if (bucket_count.find(schema_id) == bucket_count.end()) { - MS_LOG(ERROR) << "Illegal schema id: " << schema_id; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(), "Invalid schema id: " + schema_id); // check field name is or is not valid - auto result = GetSchemaByID(schema_id); - if (result.second != SUCCESS) { - MS_LOG(ERROR) << "Could not get schema by id."; - return FAILED; - } - json schema = result.first->GetSchema().at("schema"); - if (schema.find(field_name) == schema.end()) { - MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name; - return FAILED; - } - - if (CheckIndexField(field_name, schema) == FAILED) { - return FAILED; - } - + std::shared_ptr schema_ptr; + RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr)); + json schema = schema_ptr->GetSchema().at("schema"); + CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(), + "Schema " + std::to_string(schema_id) + " do not contain the field: " + field_name); + RETURN_IF_NOT_OK(CheckIndexField(field_name, schema)); field_set.insert(field); - // add field into index - index.get()->AddIndexField(schema_id, field_name); + index->AddIndexField(schema_id, field_name); } index_ = index; - return SUCCESS; + return Status::OK(); } std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { @@ -686,103 +568,71 @@ std::vector> ShardHeader::GetFields() { return std::shared_ptr ShardHeader::GetIndex() { return index_; } -std::pair, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { +Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr *schema_ptr) { + RETURN_UNEXPECTED_IF_NULL(schema_ptr); int64_t schemaSize = schema_.size(); - if (schema_id < 0 || schema_id >= schemaSize) { - MS_LOG(ERROR) << "Illegal schema id"; - return std::make_pair(nullptr, FAILED); - } - return std::make_pair(schema_.at(schema_id), SUCCESS); + CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schemaSize, "schema id is invalid."); + *schema_ptr = schema_.at(schema_id); + return Status::OK(); } -std::pair, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) { +Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr *statistics_ptr) { + RETURN_UNEXPECTED_IF_NULL(statistics_ptr); int64_t statistics_size = statistics_.size(); - if (statistic_id < 0 || statistic_id >= statistics_size) { - return std::make_pair(nullptr, FAILED); - } - return std::make_pair(statistics_.at(statistic_id), SUCCESS); + CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size, "statistic id is invalid."); + *statistics_ptr = statistics_.at(statistic_id); + return Status::OK(); } -MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { +Status ShardHeader::PagesToFile(const std::string dump_file_name) { auto realpath = Common::GetRealPath(dump_file_name); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name); // write header content to file, dump whatever is in the file before std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out); - if (page_out_handle.fail()) { - MS_LOG(ERROR) << "Failed in opening page file"; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file."); auto pages = SerializePage(); for (const auto &shard_pages : pages) { page_out_handle << shard_pages << "\n"; } - page_out_handle.close(); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { +Status ShardHeader::FileToPages(const std::string dump_file_name) { for (auto &v : pages_) { // clean pages v.clear(); } - auto realpath = Common::GetRealPath(dump_file_name); - if (!realpath.has_value()) { - MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name; - return FAILED; - } - + CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name); // attempt to open the file contains the page in json std::ifstream page_in_handle(realpath.value()); - - if (!page_in_handle.good()) { - MS_LOG(INFO) << "No page file exists."; - return SUCCESS; - } - + CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(), "No page file exists."); std::string line; while (std::getline(page_in_handle, line)) { - if (SUCCESS != ParsePage(json::parse(line), -1, true)) { - return FAILED; - } + RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true)); } - page_in_handle.close(); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardHeader::Initialize(const std::shared_ptr *header_ptr, const json &schema, - const std::vector &index_fields, std::vector &blob_fields, - uint64_t &schema_id) { - if (header_ptr == nullptr) { - MS_LOG(ERROR) << "ShardHeader pointer is NULL."; - return FAILED; - } +Status ShardHeader::Initialize(const std::shared_ptr *header_ptr, const json &schema, + const std::vector &index_fields, std::vector &blob_fields, + uint64_t &schema_id) { + RETURN_UNEXPECTED_IF_NULL(header_ptr); auto schema_ptr = Schema::Build("mindrecord", schema); - if (schema_ptr == nullptr) { - MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema."; - return FAILED; - } + RETURN_UNEXPECTED_IF_NULL(schema_ptr); schema_id = (*header_ptr)->AddSchema(schema_ptr); // create index std::vector> id_index_fields; if (!index_fields.empty()) { - (void)std::transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields), - [schema_id](const std::string &el) { return std::make_pair(schema_id, el); }); - if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) { - MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index."; - return FAILED; - } + (void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields), + [schema_id](const std::string &el) { return std::make_pair(schema_id, el); }); + RETURN_IF_NOT_OK((*header_ptr)->AddIndexFields(id_index_fields)); } auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; blob_fields = build_schema_ptr->GetBlobFields(); - return SUCCESS; + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc index aa01204d29..3974128fd5 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_pk_sample.cc @@ -37,13 +37,11 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement } -MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) { +Status ShardPkSample::SufExecute(ShardTaskList &tasks) { if (shuffle_ == true) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } + RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); } - return SUCCESS; + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc index cfad1a1106..94ee4c2092 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc @@ -80,7 +80,7 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } -MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { +Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { if (tasks.permutation_.empty()) { ShardTaskList new_tasks; int total_no = static_cast(tasks.sample_ids_.size()); @@ -110,9 +110,7 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { ShardTaskList::TaskListSwap(tasks, new_tasks); } else { ShardTaskList new_tasks; - if (taking > static_cast(tasks.sample_ids_.size())) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast(tasks.sample_ids_.size()), "taking is out of range."); int total_no = static_cast(tasks.permutation_.size()); int cnt = 0; for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { @@ -122,10 +120,10 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { } ShardTaskList::TaskListSwap(tasks, new_tasks); } - return SUCCESS; + return Status::OK(); } -MSRStatus ShardSample::Execute(ShardTaskList &tasks) { +Status ShardSample::Execute(ShardTaskList &tasks) { if (offset_ != -1) { int64_t old_v = 0; int num_rows_ = static_cast(tasks.sample_ids_.size()); @@ -146,10 +144,8 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) { no_of_samples_ = std::min(no_of_samples_, total_no); taking = no_of_samples_ - no_of_samples_ % no_of_categories; } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) { - if (indices_.size() > static_cast(total_no)) { - MS_LOG(ERROR) << "parameter indices's size is greater than dataset size."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast(total_no), + "Parameter indices's size is greater than dataset size."); } else { // constructor TopPercent if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { if (numerator_ == 1 && denominator_ > 1) { // sharding @@ -159,20 +155,17 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) { taking -= (taking % no_of_categories); } } else { - MS_LOG(ERROR) << "parameter numerator or denominator is illegal"; - return FAILED; + RETURN_STATUS_UNEXPECTED("Parameter numerator or denominator is invalid."); } } return UpdateTasks(tasks, taking); } -MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) { +Status ShardSample::SufExecute(ShardTaskList &tasks) { if (sampler_type_ == kSubsetRandomSampler) { - if (SUCCESS != (*shuffle_op_)(tasks)) { - return FAILED; - } + RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); } - return SUCCESS; + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc index b6d37bf836..6cbfe8c2ee 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_schema.cc @@ -38,12 +38,6 @@ std::shared_ptr Schema::Build(std::string desc, const json &schema) { return std::make_shared(object_schema); } -std::shared_ptr Schema::Build(std::string desc, pybind11::handle schema) { - // validate check - json schema_json = nlohmann::detail::ToJsonImpl(schema); - return Build(std::move(desc), schema_json); -} - std::string Schema::GetDesc() const { return desc_; } json Schema::GetSchema() const { @@ -54,12 +48,6 @@ json Schema::GetSchema() const { return str_schema; } -pybind11::object Schema::GetSchemaForPython() const { - json schema_json = GetSchema(); - pybind11::object schema_py = nlohmann::detail::FromJsonImpl(schema_json); - return schema_py; -} - void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } int64_t Schema::GetSchemaID() const { return schema_id_; } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc index c7f57f43cd..67e1ce8368 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_sequential_sample.cc @@ -38,7 +38,7 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c return std::min(static_cast(no_of_samples_), dataset_size); } -MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { +Status ShardSequentialSample::Execute(ShardTaskList &tasks) { int64_t taking; int64_t total_no = static_cast(tasks.sample_ids_.size()); if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { @@ -58,16 +58,15 @@ MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { ShardTaskList::TaskListSwap(tasks, new_tasks); } else { // shuffled ShardTaskList new_tasks; - if (taking > static_cast(tasks.permutation_.size())) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast(tasks.permutation_.size()), + "Taking is out of task range."); total_no = static_cast(tasks.permutation_.size()); for (size_t i = offset_; i < taking + offset_; ++i) { new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); } ShardTaskList::TaskListSwap(tasks, new_tasks); } - return SUCCESS; + return Status::OK(); } } // namespace mindrecord diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc index fe472a9cd8..cfd02f793c 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_shuffle.cc @@ -42,7 +42,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); } -MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { +Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories; std::vector> new_permutations(tasks.categories, std::vector(individual_size)); for (uint32_t i = 0; i < tasks.categories; i++) { @@ -62,17 +62,14 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { } ShardTaskList::TaskListSwap(tasks, new_tasks); - return SUCCESS; + return Status::OK(); } -MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { +Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { if (no_of_samples_ == 0) { no_of_samples_ = static_cast(tasks.Size()); } - if (no_of_samples_ <= 0) { - MS_LOG(ERROR) << "no_of_samples need to be positive."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); auto shard_sample_cout = GetShardSampleCount(); // shuffle the files index @@ -118,16 +115,14 @@ MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { new_tasks.AssignTask(tasks, tasks.permutation_[i]); } ShardTaskList::TaskListSwap(tasks, new_tasks); + return Status::OK(); } -MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { +Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { if (no_of_samples_ == 0) { no_of_samples_ = static_cast(tasks.Size()); } - if (no_of_samples_ <= 0) { - MS_LOG(ERROR) << "no_of_samples need to be positive."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); // reconstruct the permutation in file // -- before -- // file1: [0, 1, 2] @@ -154,13 +149,12 @@ MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { new_tasks.AssignTask(tasks, tasks.permutation_[i]); } ShardTaskList::TaskListSwap(tasks, new_tasks); + return Status::OK(); } -MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { +Status ShardShuffle::Execute(ShardTaskList &tasks) { if (reshuffle_each_epoch_) shuffle_seed_++; - if (tasks.categories < 1) { - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(tasks.categories >= 1, "Task category is invalid."); if (shuffle_type_ == kShuffleSample) { // shuffle each sample if (tasks.permutation_.empty() == true) { tasks.MakePerm(); @@ -169,10 +163,7 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { if (replacement_ == true) { ShardTaskList new_tasks; if (no_of_samples_ == 0) no_of_samples_ = static_cast(tasks.sample_ids_.size()); - if (no_of_samples_ <= 0) { - MS_LOG(ERROR) << "no_of_samples need to be positive."; - return FAILED; - } + CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); for (uint32_t i = 0; i < no_of_samples_; ++i) { new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); } @@ -190,20 +181,14 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { ShardTaskList::TaskListSwap(tasks, new_tasks); } } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) { - auto ret = ShuffleInfile(tasks); - if (ret != SUCCESS) { - return ret; - } + RETURN_IF_NOT_OK(ShuffleInfile(tasks)); } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) { - auto ret = ShuffleFiles(tasks); - if (ret != SUCCESS) { - return ret; - } + RETURN_IF_NOT_OK(ShuffleFiles(tasks)); } } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) return this->CategoryShuffle(tasks); } - return SUCCESS; + return Status::OK(); } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc index 7024a2ab06..869466ba54 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_statistics.cc @@ -35,19 +35,6 @@ std::shared_ptr Statistics::Build(std::string desc, const json &stat return std::make_shared(object_statistics); } -std::shared_ptr Statistics::Build(std::string desc, pybind11::handle statistics) { - // validate check - json statistics_json = nlohmann::detail::ToJsonImpl(statistics); - if (!Validate(statistics_json)) { - return nullptr; - } - Statistics object_statistics; - object_statistics.desc_ = std::move(desc); - object_statistics.statistics_ = statistics_json; - object_statistics.statistics_id_ = -1; - return std::make_shared(object_statistics); -} - std::string Statistics::GetDesc() const { return desc_; } json Statistics::GetStatistics() const { @@ -57,11 +44,6 @@ json Statistics::GetStatistics() const { return str_statistics; } -pybind11::object Statistics::GetStatisticsForPython() const { - json str_statistics = Statistics::GetStatistics(); - return nlohmann::detail::FromJsonImpl(str_statistics); -} - void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } int64_t Statistics::GetStatisticsID() const { return statistics_id_; } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc index 5f52630ccf..9016aa8898 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_task_list.cc @@ -85,15 +85,9 @@ uint32_t ShardTaskList::SizeOfRows() const { return nRows; } -ShardTask &ShardTaskList::GetTaskByID(size_t id) { - MS_ASSERT(id < task_list_.size()); - return task_list_[id]; -} +ShardTask &ShardTaskList::GetTaskByID(size_t id) { return task_list_[id]; } -int ShardTaskList::GetTaskSampleByID(size_t id) { - MS_ASSERT(id < sample_ids_.size()); - return sample_ids_[id]; -} +int ShardTaskList::GetTaskSampleByID(size_t id) { return sample_ids_[id]; } int ShardTaskList::GetRandomTaskID() { std::mt19937 gen = mindspore::dataset::GetRandomDevice(); diff --git a/mindspore/mindrecord/shardreader.py b/mindspore/mindrecord/shardreader.py index 0714af8374..3cc7692ebc 100644 --- a/mindspore/mindrecord/shardreader.py +++ b/mindspore/mindrecord/shardreader.py @@ -70,7 +70,7 @@ class ShardReader: Raises: MRMLaunchError: If failed to launch worker threads. """ - ret = self._reader.launch(False) + ret = self._reader.launch() if ret != ms.MSRStatus.SUCCESS: logger.error("Failed to launch worker threads.") raise MRMLaunchError diff --git a/mindspore/mindrecord/shardsegment.py b/mindspore/mindrecord/shardsegment.py index bda1b02959..eee5c6b6d4 100644 --- a/mindspore/mindrecord/shardsegment.py +++ b/mindspore/mindrecord/shardsegment.py @@ -19,7 +19,6 @@ import mindspore._c_mindrecord as ms from mindspore import log as logger from .shardutils import populate_data, SUCCESS from .shardheader import ShardHeader -from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError __all__ = ['ShardSegment'] @@ -73,15 +72,8 @@ class ShardSegment: Returns: list[str], by which data could be grouped. - Raises: - MRMFetchCandidateFieldsError: If failed to get candidate category fields. """ - ret, fields = self._segment.get_category_fields() - if ret != SUCCESS: - logger.error("Failed to get candidate category fields.") - raise MRMFetchCandidateFieldsError - return fields - + return self._segment.get_category_fields() def set_category_field(self, category_field): """Select one category field to use.""" @@ -94,14 +86,8 @@ class ShardSegment: Returns: str, description fo group information. - Raises: - MRMReadCategoryInfoError: If failed to read category information. """ - ret, category_info = self._segment.read_category_info() - if ret != SUCCESS: - logger.error("Failed to read category information.") - raise MRMReadCategoryInfoError - return category_info + return self._segment.read_category_info() def read_at_page_by_id(self, category_id, page, num_row): """ @@ -116,13 +102,9 @@ class ShardSegment: list[dict] Raises: - MRMFetchDataError: If failed to read by category id. MRMUnsupportedSchemaError: If schema is invalid. """ - ret, data = self._segment.read_at_page_by_id(category_id, page, num_row) - if ret != SUCCESS: - logger.error("Failed to read by category id.") - raise MRMFetchDataError + data = self._segment.read_at_page_by_id(category_id, page, num_row) return [populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) for blob, raw in data] @@ -139,12 +121,8 @@ class ShardSegment: list[dict] Raises: - MRMFetchDataError: If failed to read by category name. MRMUnsupportedSchemaError: If schema is invalid. """ - ret, data = self._segment.read_at_page_by_name(category_name, page, num_row) - if ret != SUCCESS: - logger.error("Failed to read by category name.") - raise MRMFetchDataError + data = self._segment.read_at_page_by_name(category_name, page, num_row) return [populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) for blob, raw in data] diff --git a/tests/ut/cpp/mindrecord/ut_common.cc b/tests/ut/cpp/mindrecord/ut_common.cc index 2d2d69bd54..78bb946d78 100644 --- a/tests/ut/cpp/mindrecord/ut_common.cc +++ b/tests/ut/cpp/mindrecord/ut_common.cc @@ -384,8 +384,8 @@ void ShardWriterImageNetOpenForAppend(string filename) { { MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; mindrecord::ShardWriter fw; - auto ret = fw.OpenForAppend(filename); - if (ret == FAILED) { + auto status = fw.OpenForAppend(filename); + if (status.IsError()) { return; } diff --git a/tests/ut/cpp/mindrecord/ut_shard.cc b/tests/ut/cpp/mindrecord/ut_shard.cc index 11492e9f28..fbf6c8a638 100644 --- a/tests/ut/cpp/mindrecord/ut_shard.cc +++ b/tests/ut/cpp/mindrecord/ut_shard.cc @@ -121,14 +121,19 @@ TEST_F(TestShard, TestShardHeaderPart) { re_statistics.push_back(*statistic); } ASSERT_EQ(re_statistics, validate_statistics); - ASSERT_EQ(header_data.GetStatisticByID(-1).second, FAILED); - ASSERT_EQ(header_data.GetStatisticByID(10).second, FAILED); + std::shared_ptr statistics_ptr; + + auto status = header_data.GetStatisticByID(-1, &statistics_ptr); + EXPECT_FALSE(status.IsOk()); + status = header_data.GetStatisticByID(10, &statistics_ptr); + EXPECT_FALSE(status.IsOk()); // test add index fields std::vector> fields; std::pair pair1(0, "name"); fields.push_back(pair1); - ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS); + status = header_data.AddIndexFields(fields); + EXPECT_TRUE(status.IsOk()); std::vector> resFields = header_data.GetFields(); ASSERT_EQ(resFields, fields); } diff --git a/tests/ut/cpp/mindrecord/ut_shard_header_test.cc b/tests/ut/cpp/mindrecord/ut_shard_header_test.cc index 2ff3d1655d..20c41fa879 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_header_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_header_test.cc @@ -79,36 +79,37 @@ TEST_F(TestShardHeader, AddIndexFields) { std::pair index_field2(schema_id1, "box"); fields.push_back(index_field1); fields.push_back(index_field2); - MSRStatus res = header_data.AddIndexFields(fields); - ASSERT_EQ(res, SUCCESS); + Status status = header_data.AddIndexFields(fields); + EXPECT_TRUE(status.IsOk()); + ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field3(schema_id1, "name"); fields.push_back(index_field3); - res = header_data.AddIndexFields(fields); - ASSERT_EQ(res, FAILED); + status = header_data.AddIndexFields(fields); + EXPECT_FALSE(status.IsOk()); ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field4(schema_id1, "names"); fields.push_back(index_field4); - res = header_data.AddIndexFields(fields); - ASSERT_EQ(res, FAILED); + status = header_data.AddIndexFields(fields); + EXPECT_FALSE(status.IsOk()); ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field5(schema_id1 + 1, "name"); fields.push_back(index_field5); - res = header_data.AddIndexFields(fields); - ASSERT_EQ(res, FAILED); + status = header_data.AddIndexFields(fields); + EXPECT_FALSE(status.IsOk()); ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field6(schema_id1, "label"); fields.push_back(index_field6); - res = header_data.AddIndexFields(fields); - ASSERT_EQ(res, FAILED); + status = header_data.AddIndexFields(fields); + EXPECT_FALSE(status.IsOk()); ASSERT_EQ(header_data.GetFields().size(), 2); std::string desc_new = "this is a test1"; @@ -129,26 +130,26 @@ TEST_F(TestShardHeader, AddIndexFields) { single_fields.push_back("name"); single_fields.push_back("name"); single_fields.push_back("box"); - res = header_data_new.AddIndexFields(single_fields); - ASSERT_EQ(res, FAILED); + status = header_data_new.AddIndexFields(single_fields); + EXPECT_FALSE(status.IsOk()); ASSERT_EQ(header_data_new.GetFields().size(), 1); single_fields.push_back("name"); single_fields.push_back("box"); - res = header_data_new.AddIndexFields(single_fields); - ASSERT_EQ(res, FAILED); + status = header_data_new.AddIndexFields(single_fields); + EXPECT_FALSE(status.IsOk()); ASSERT_EQ(header_data_new.GetFields().size(), 1); single_fields.clear(); single_fields.push_back("names"); - res = header_data_new.AddIndexFields(single_fields); - ASSERT_EQ(res, FAILED); + status = header_data_new.AddIndexFields(single_fields); + EXPECT_FALSE(status.IsOk()); ASSERT_EQ(header_data_new.GetFields().size(), 1); single_fields.clear(); single_fields.push_back("box"); - res = header_data_new.AddIndexFields(single_fields); - ASSERT_EQ(res, SUCCESS); + status = header_data_new.AddIndexFields(single_fields); + EXPECT_TRUE(status.IsOk()); ASSERT_EQ(header_data_new.GetFields().size(), 2); } } // namespace mindrecord diff --git a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc index a7102ee918..9dff522e04 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_reader_test.cc @@ -167,8 +167,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { std::string file_name = "./imagenet.shard01"; auto column_list = std::vector{"label"}; ShardReader dataset; - MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); - ASSERT_EQ(ret, SUCCESS); + auto status = dataset.Open({file_name}, true, 4, column_list); + EXPECT_TRUE(status.IsOk()); dataset.Launch(); while (true) { @@ -188,16 +188,16 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { std::string file_name = "./imagenet.shard01"; auto column_list = std::vector{"file_namex"}; ShardReader dataset; - MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); - ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST); + auto status= dataset.Open({file_name}, true, 4, column_list); + EXPECT_FALSE(status.IsOk()); } TEST_F(TestShardReader, TestShardVersion) { MS_LOG(INFO) << FormatInfo("Test shard version"); std::string file_name = "./imagenet.shard01"; ShardReader dataset; - MSRStatus ret = dataset.Open({file_name}, true, 4); - ASSERT_EQ(ret, SUCCESS); + auto status = dataset.Open({file_name}, true, 4); + EXPECT_TRUE(status.IsOk()); dataset.Launch(); while (true) { @@ -219,8 +219,8 @@ TEST_F(TestShardReader, TestShardReaderDir) { auto column_list = std::vector{"file_name"}; ShardReader dataset; - MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); - ASSERT_EQ(ret, FAILED); + auto status = dataset.Open({file_name}, true, 4, column_list); + EXPECT_FALSE(status.IsOk()); } TEST_F(TestShardReader, TestShardReaderConsumer) { diff --git a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc index a4900a51f2..a9c112fccf 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_segment_test.cc @@ -61,35 +61,44 @@ TEST_F(TestShardSegment, TestShardSegment) { ShardSegment dataset; dataset.Open({file_name}, true, 4); - auto x = dataset.GetCategoryFields(); - for (const auto &fields : x.second) { + auto fields_ptr = std::make_shared>(); + auto status = dataset.GetCategoryFields(&fields_ptr); + for (const auto &fields : *fields_ptr) { MS_LOG(INFO) << "Get category field: " << fields; } - ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); - ASSERT_TRUE(dataset.SetCategoryField("laabel_0") == FAILED); + status = dataset.SetCategoryField("label"); + EXPECT_TRUE(status.IsOk()); + status = dataset.SetCategoryField("laabel_0"); + EXPECT_FALSE(status.IsOk()); - MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; - auto ret = dataset.ReadAtPageByName("822", 0, 10); - auto images = ret.second; - MS_LOG(INFO) << "category field: 822, images count: " << images.size() << ", image[0] size: " << images[0].size(); + std::shared_ptr category_ptr; + status = dataset.ReadCategoryInfo(&category_ptr); + EXPECT_TRUE(status.IsOk()); + MS_LOG(INFO) << "Read category info: " << *category_ptr; - auto ret1 = dataset.ReadAtPageByName("823", 0, 10); - auto images2 = ret1.second; - MS_LOG(INFO) << "category field: 823, images count: " << images2.size(); + auto pages_ptr = std::make_shared>>(); + status = dataset.ReadAtPageByName("822", 0, 10, &pages_ptr); + EXPECT_TRUE(status.IsOk()); + MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr->size() << ", image[0] size: " << ((*pages_ptr)[0]).size(); - auto ret2 = dataset.ReadAtPageById(1, 0, 10); - auto images3 = ret2.second; - MS_LOG(INFO) << "category id: 1, images count: " << images3.size() << ", image[0] size: " << images3[0].size(); + auto pages_ptr_1 = std::make_shared>>(); + status = dataset.ReadAtPageByName("823", 0, 10, &pages_ptr_1); + MS_LOG(INFO) << "category field: 823, images count: " << pages_ptr_1->size(); - auto ret3 = dataset.ReadAllAtPageByName("822", 0, 10); - auto images4 = ret3.second; - MS_LOG(INFO) << "category field: 822, images count: " << images4.size(); + auto pages_ptr_2 = std::make_shared>>(); + status = dataset.ReadAtPageById(1, 0, 10, &pages_ptr_2); + EXPECT_TRUE(status.IsOk()); + MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_2->size() << ", image[0] size: " << ((*pages_ptr_2)[0]).size(); - auto ret4 = dataset.ReadAllAtPageById(1, 0, 10); - auto images5 = ret4.second; - MS_LOG(INFO) << "category id: 1, images count: " << images5.size(); + auto pages_ptr_3 = std::make_shared(); + status = dataset.ReadAllAtPageByName("822", 0, 10, &pages_ptr_3); + MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr_3->size(); + + auto pages_ptr_4 = std::make_shared(); + status = dataset.ReadAllAtPageById(1, 0, 10, &pages_ptr_4); + MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_4->size(); } TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { @@ -99,21 +108,28 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { ShardSegment dataset; dataset.Open({file_name}, true, 4); - auto x = dataset.GetCategoryFields(); - for (const auto &fields : x.second) { + auto fields_ptr = std::make_shared>(); + auto status = dataset.GetCategoryFields(&fields_ptr); + for (const auto &fields : *fields_ptr) { MS_LOG(INFO) << "Get category field: " << fields; } string category_name = "82Cus"; string category_field = "laabel_0"; - ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); - ASSERT_TRUE(dataset.SetCategoryField(category_field) == FAILED); + status = dataset.SetCategoryField("label"); + EXPECT_TRUE(status.IsOk()); + status = dataset.SetCategoryField(category_field); + EXPECT_FALSE(status.IsOk()); - MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; + std::shared_ptr category_ptr; + status = dataset.ReadCategoryInfo(&category_ptr); + EXPECT_TRUE(status.IsOk()); + MS_LOG(INFO) << "Read category info: " << *category_ptr; - auto ret = dataset.ReadAtPageByName(category_name, 0, 10); - EXPECT_TRUE(ret.first == FAILED); + auto pages_ptr = std::make_shared>>(); + status = dataset.ReadAtPageByName(category_name, 0, 10, &pages_ptr); + EXPECT_FALSE(status.IsOk()); } TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { @@ -123,19 +139,25 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { ShardSegment dataset; dataset.Open({file_name}, true, 4); - auto x = dataset.GetCategoryFields(); - for (const auto &fields : x.second) { + auto fields_ptr = std::make_shared>(); + auto status = dataset.GetCategoryFields(&fields_ptr); + for (const auto &fields : *fields_ptr) { MS_LOG(INFO) << "Get category field: " << fields; } int64_t categoryId = 2251799813685247; MS_LOG(INFO) << "Input category id: " << categoryId; - ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); - MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; + status = dataset.SetCategoryField("label"); + EXPECT_TRUE(status.IsOk()); + std::shared_ptr category_ptr; + status = dataset.ReadCategoryInfo(&category_ptr); + EXPECT_TRUE(status.IsOk()); + MS_LOG(INFO) << "Read category info: " << *category_ptr; - auto ret2 = dataset.ReadAtPageById(categoryId, 0, 10); - EXPECT_TRUE(ret2.first == FAILED); + auto pages_ptr = std::make_shared>>(); + status = dataset.ReadAtPageById(categoryId, 0, 10, &pages_ptr); + EXPECT_FALSE(status.IsOk()); } TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { @@ -145,19 +167,27 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { ShardSegment dataset; dataset.Open({file_name}, true, 4); - auto x = dataset.GetCategoryFields(); - for (const auto &fields : x.second) { + auto fields_ptr = std::make_shared>(); + auto status = dataset.GetCategoryFields(&fields_ptr); + for (const auto &fields : *fields_ptr) { MS_LOG(INFO) << "Get category field: " << fields; } int64_t page_no = 2251799813685247; MS_LOG(INFO) << "Input page no: " << page_no; - ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); - MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; + status = dataset.SetCategoryField("label"); + EXPECT_TRUE(status.IsOk()); + - auto ret2 = dataset.ReadAtPageById(1, page_no, 10); - EXPECT_TRUE(ret2.first == FAILED); + std::shared_ptr category_ptr; + status = dataset.ReadCategoryInfo(&category_ptr); + EXPECT_TRUE(status.IsOk()); + MS_LOG(INFO) << "Read category info: " << *category_ptr; + + auto pages_ptr = std::make_shared>>(); + status = dataset.ReadAtPageById(1, page_no, 10, &pages_ptr); + EXPECT_FALSE(status.IsOk()); } TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { @@ -167,19 +197,26 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { ShardSegment dataset; dataset.Open({file_name}, true, 4); - auto x = dataset.GetCategoryFields(); - for (const auto &fields : x.second) { + auto fields_ptr = std::make_shared>(); + auto status = dataset.GetCategoryFields(&fields_ptr); + for (const auto &fields : *fields_ptr) { MS_LOG(INFO) << "Get category field: " << fields; } int64_t pageRows = 0; MS_LOG(INFO) << "Input page rows: " << pageRows; - ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); - MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; + status = dataset.SetCategoryField("label"); + EXPECT_TRUE(status.IsOk()); + + std::shared_ptr category_ptr; + status = dataset.ReadCategoryInfo(&category_ptr); + EXPECT_TRUE(status.IsOk()); + MS_LOG(INFO) << "Read category info: " << *category_ptr; - auto ret2 = dataset.ReadAtPageById(1, 0, pageRows); - EXPECT_TRUE(ret2.first == FAILED); + auto pages_ptr = std::make_shared>>(); + status = dataset.ReadAtPageById(1, 0, pageRows, &pages_ptr); + EXPECT_FALSE(status.IsOk()); } } // namespace mindrecord diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index 031c62f917..c91c5356d4 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -60,8 +60,8 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { std::string filename = "./OneSample.shard01"; ShardReader dataset; - MSRStatus ret = dataset.Open({filename}, true, 4); - ASSERT_EQ(ret, SUCCESS); + auto status = dataset.Open({filename}, true, 4); + EXPECT_TRUE(status.IsOk()); dataset.Launch(); while (true) { @@ -675,8 +675,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) { fw.SetShardHeader(std::make_shared(header_data)); // write rawdata - MSRStatus res = fw.WriteRawData(rawdatas, bin_data); - ASSERT_EQ(res, SUCCESS); + auto status = fw.WriteRawData(rawdatas, bin_data); + EXPECT_TRUE(status.IsOk()); for (const auto &filename : file_names) { auto filename_db = filename + ".db"; remove(common::SafeCStr(filename_db)); @@ -716,7 +716,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { fields.push_back(index_field2); // add index to shardHeader - ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS); + auto status = header_data.AddIndexFields(fields); + EXPECT_TRUE(status.IsOk()); MS_LOG(INFO) << "Init Index Fields Already."; // load meta data @@ -736,28 +737,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { } mindrecord::ShardWriter fw_init; - ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); + status = fw_init.Open(file_names); + EXPECT_TRUE(status.IsOk()); // set shardHeader - ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared(header_data)) == SUCCESS); + status = fw_init.SetShardHeader(std::make_shared(header_data)); + EXPECT_TRUE(status.IsOk()); + // write raw data - ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); - ASSERT_TRUE(fw_init.Commit() == SUCCESS); + status = fw_init.WriteRawData(rawdatas, bin_data); + EXPECT_TRUE(status.IsOk()); + status = fw_init.Commit(); + EXPECT_TRUE(status.IsOk()); // create the index file std::string filename = "./imagenet.shard01"; mindrecord::ShardIndexGenerator sg{filename}; sg.Build(); - ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); + status = sg.WriteToDatabase(); + EXPECT_TRUE(status.IsOk()); MS_LOG(INFO) << "Done create index"; // read the mindrecord file filename = "./imagenet.shard01"; auto column_list = std::vector{"label", "file_name", "data"}; ShardReader dataset; - MSRStatus ret = dataset.Open({filename}, true, 4, column_list); - ASSERT_EQ(ret, SUCCESS); + status = dataset.Open({filename}, true, 4, column_list); + EXPECT_TRUE(status.IsOk()); dataset.Launch(); int count = 0; @@ -822,28 +829,34 @@ TEST_F(TestShardWriter, TestShardNoBlob) { } mindrecord::ShardWriter fw_init; - ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); + auto status = fw_init.Open(file_names); + EXPECT_TRUE(status.IsOk()); + // set shardHeader - ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared(header_data)) == SUCCESS); + status = fw_init.SetShardHeader(std::make_shared(header_data)); + EXPECT_TRUE(status.IsOk()); // write raw data - ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); - ASSERT_TRUE(fw_init.Commit() == SUCCESS); + status = fw_init.WriteRawData(rawdatas, bin_data); + EXPECT_TRUE(status.IsOk()); + status = fw_init.Commit(); + EXPECT_TRUE(status.IsOk()); // create the index file std::string filename = "./imagenet.shard01"; mindrecord::ShardIndexGenerator sg{filename}; sg.Build(); - ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); + status = sg.WriteToDatabase(); + EXPECT_TRUE(status.IsOk()); MS_LOG(INFO) << "Done create index"; // read the mindrecord file filename = "./imagenet.shard01"; auto column_list = std::vector{"label", "file_name"}; ShardReader dataset; - MSRStatus ret = dataset.Open({filename}, true, 4, column_list); - ASSERT_EQ(ret, SUCCESS); + status = dataset.Open({filename}, true, 4, column_list); + EXPECT_TRUE(status.IsOk()); dataset.Launch(); int count = 0; @@ -896,7 +909,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { fields.push_back(index_field1); // add index to shardHeader - ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS); + auto status = header_data.AddIndexFields(fields); + EXPECT_TRUE(status.IsOk()); MS_LOG(INFO) << "Init Index Fields Already."; // load meta data @@ -916,28 +930,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { } mindrecord::ShardWriter fw_init; - ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); + status = fw_init.Open(file_names); + EXPECT_TRUE(status.IsOk()); + // set shardHeader - ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared(header_data)) == SUCCESS); + status = fw_init.SetShardHeader(std::make_shared(header_data)); + EXPECT_TRUE(status.IsOk()); // write raw data - ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); - ASSERT_TRUE(fw_init.Commit() == SUCCESS); + status = fw_init.WriteRawData(rawdatas, bin_data); + EXPECT_TRUE(status.IsOk()); + status = fw_init.Commit(); + EXPECT_TRUE(status.IsOk()); // create the index file std::string filename = "./imagenet.shard01"; mindrecord::ShardIndexGenerator sg{filename}; sg.Build(); - ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); + status = sg.WriteToDatabase(); + EXPECT_TRUE(status.IsOk()); MS_LOG(INFO) << "Done create index"; // read the mindrecord file filename = "./imagenet.shard01"; auto column_list = std::vector{"label", "data"}; ShardReader dataset; - MSRStatus ret = dataset.Open({filename}, true, 4, column_list); - ASSERT_EQ(ret, SUCCESS); + status = dataset.Open({filename}, true, 4, column_list); + EXPECT_TRUE(status.IsOk()); dataset.Launch(); int count = 0; @@ -1043,8 +1063,8 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { filename = "./TenSampleFortyShard.shard01"; ShardReader dataset; - MSRStatus ret = dataset.Open({filename}, true, 4); - ASSERT_EQ(ret, SUCCESS); + auto status = dataset.Open({filename}, true, 4); + EXPECT_TRUE(status.IsOk()); dataset.Launch(); int count = 0; diff --git a/tests/ut/python/dataset/test_minddataset_exception.py b/tests/ut/python/dataset/test_minddataset_exception.py index ff8ab28197..79c2b09851 100644 --- a/tests/ut/python/dataset/test_minddataset_exception.py +++ b/tests/ut/python/dataset/test_minddataset_exception.py @@ -95,7 +95,7 @@ def test_invalid_mindrecord(): f.write('just for test') columns_list = ["data", "file_name", "label"] num_readers = 4 - with pytest.raises(Exception, match="MindRecordOp init failed"): + with pytest.raises(RuntimeError, match="Unexpected error. Invalid file content. path:"): data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers) num_iter = 0 for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): @@ -114,7 +114,7 @@ def test_minddataset_lack_db(): os.remove("{}.db".format(CV_FILE_NAME)) columns_list = ["data", "file_name", "label"] num_readers = 4 - with pytest.raises(Exception, match="MindRecordOp init failed"): + with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file:"): data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) num_iter = 0 for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): @@ -133,7 +133,7 @@ def test_cv_minddataset_pk_sample_error_class_column(): columns_list = ["data", "file_name", "label"] num_readers = 4 sampler = ds.PKSampler(5, None, True, 'no_exist_column') - with pytest.raises(Exception, match="MindRecordOp launch failed"): + with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."): data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler) num_iter = 0 for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): @@ -162,7 +162,7 @@ def test_cv_minddataset_reader_different_schema(): create_diff_schema_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="MindRecordOp init failed"): + with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"): data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, num_readers) num_iter = 0 @@ -179,7 +179,7 @@ def test_cv_minddataset_reader_different_page_size(): create_diff_page_size_cv_mindrecord(1) columns_list = ["data", "label"] num_readers = 4 - with pytest.raises(Exception, match="MindRecordOp init failed"): + with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"): data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, num_readers) num_iter = 0 diff --git a/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py b/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py index 37d13f0c2b..33632af253 100644 --- a/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_cifar100_to_mindrecord.py @@ -19,7 +19,6 @@ import pytest from mindspore import log as logger from mindspore.mindrecord import Cifar100ToMR from mindspore.mindrecord import FileReader -from mindspore.mindrecord import MRMOpenError from mindspore.mindrecord import SUCCESS CIFAR100_DIR = "../data/mindrecord/testCifar100Data" @@ -119,8 +118,8 @@ def test_cifar100_to_mindrecord_directory(fixture_file): test transform cifar10 dataset to mindrecord when destination path is directory. """ - with pytest.raises(MRMOpenError, - match="MindRecord File could not open successfully"): + with pytest.raises(RuntimeError, + match="MindRecord file already existed, please delete file:"): cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, CIFAR100_DIR) cifar100_transformer.transform() @@ -130,8 +129,8 @@ def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file): test transform cifar10 dataset to mindrecord when destination path equals source path. """ - with pytest.raises(MRMOpenError, - match="MindRecord File could not open successfully"): + with pytest.raises(RuntimeError, + match="indRecord file already existed, please delete file:"): cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, CIFAR100_DIR + "/train") cifar100_transformer.transform() diff --git a/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py b/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py index 5464cc0e50..500be0d5b6 100644 --- a/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py +++ b/tests/ut/python/mindrecord/test_cifar10_to_mindrecord.py @@ -19,7 +19,7 @@ import pytest from mindspore import log as logger from mindspore.mindrecord import Cifar10ToMR from mindspore.mindrecord import FileReader -from mindspore.mindrecord import MRMOpenError, SUCCESS +from mindspore.mindrecord import SUCCESS CIFAR10_DIR = "../data/mindrecord/testCifar10Data" MINDRECORD_FILE = "./cifar10.mindrecord" @@ -146,8 +146,8 @@ def test_cifar10_to_mindrecord_directory(fixture_file): test transform cifar10 dataset to mindrecord when destination path is directory. """ - with pytest.raises(MRMOpenError, - match="MindRecord File could not open successfully"): + with pytest.raises(RuntimeError, + match="MindRecord file already existed, please delete file:"): cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR) cifar10_transformer.transform() @@ -157,8 +157,8 @@ def test_cifar10_to_mindrecord_filename_equals_cifar10(): test transform cifar10 dataset to mindrecord when destination path equals source path. """ - with pytest.raises(MRMOpenError, - match="MindRecord File could not open successfully"): + with pytest.raises(RuntimeError, + match="MindRecord file already existed, please delete file:"): cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR + "/data_batch_0") cifar10_transformer.transform() diff --git a/tests/ut/python/mindrecord/test_mindrecord_exception.py b/tests/ut/python/mindrecord/test_mindrecord_exception.py index f521154fb7..1710b5834c 100644 --- a/tests/ut/python/mindrecord/test_mindrecord_exception.py +++ b/tests/ut/python/mindrecord/test_mindrecord_exception.py @@ -21,8 +21,7 @@ from utils import get_data from mindspore import log as logger from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS -from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \ - MRMFetchDataError +from mindspore.mindrecord import ParamValueError, MRMGetMetaError CV_FILE_NAME = "./imagenet.mindrecord" NLP_FILE_NAME = "./aclImdb.mindrecord" @@ -106,21 +105,19 @@ def create_cv_mindrecord(files_num): def test_lack_partition_and_db(): """test file reader when mindrecord file does not exist.""" - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: reader = FileReader('dummy.mindrecord') reader.close() - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid file path:' in str(err.value) def test_lack_db(fixture_cv_file): """test file reader when db file does not exist.""" create_cv_mindrecord(1) os.remove("{}.db".format(CV_FILE_NAME)) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: reader = FileReader(CV_FILE_NAME) reader.close() - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid database file:' in str(err.value) def test_lack_some_partition_and_db(fixture_cv_file): """test file reader when some partition and db do not exist.""" @@ -129,11 +126,10 @@ def test_lack_some_partition_and_db(fixture_cv_file): for x in range(FILES_NUM)] os.remove("{}".format(paths[3])) os.remove("{}.db".format(paths[3])) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: reader = FileReader(CV_FILE_NAME + "0") reader.close() - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid file path:' in str(err.value) def test_lack_some_partition_first(fixture_cv_file): """test file reader when first partition does not exist.""" @@ -141,11 +137,10 @@ def test_lack_some_partition_first(fixture_cv_file): paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] os.remove("{}".format(paths[0])) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: reader = FileReader(CV_FILE_NAME + "0") reader.close() - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid file path:' in str(err.value) def test_lack_some_partition_middle(fixture_cv_file): """test file reader when some partition does not exist.""" @@ -153,11 +148,10 @@ def test_lack_some_partition_middle(fixture_cv_file): paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] os.remove("{}".format(paths[1])) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: reader = FileReader(CV_FILE_NAME + "0") reader.close() - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid file path:' in str(err.value) def test_lack_some_partition_last(fixture_cv_file): """test file reader when last partition does not exist.""" @@ -165,11 +159,10 @@ def test_lack_some_partition_last(fixture_cv_file): paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] os.remove("{}".format(paths[3])) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: reader = FileReader(CV_FILE_NAME + "0") reader.close() - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid file path:' in str(err.value) def test_mindpage_lack_some_partition(fixture_cv_file): """test page reader when some partition does not exist.""" @@ -177,10 +170,9 @@ def test_mindpage_lack_some_partition(fixture_cv_file): paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] os.remove("{}".format(paths[0])) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: MindPage(CV_FILE_NAME + "0") - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid file path:' in str(err.value) def test_lack_some_db(fixture_cv_file): """test file reader when some db does not exist.""" @@ -188,11 +180,10 @@ def test_lack_some_db(fixture_cv_file): paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) for x in range(FILES_NUM)] os.remove("{}.db".format(paths[3])) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: reader = FileReader(CV_FILE_NAME + "0") reader.close() - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid database file:' in str(err.value) def test_invalid_mindrecord(): @@ -200,10 +191,9 @@ def test_invalid_mindrecord(): with open(CV_FILE_NAME, 'w') as f: dummy = 's' * 100 f.write(dummy) - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: FileReader(CV_FILE_NAME) - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Invalid file content. path:' in str(err.value) os.remove(CV_FILE_NAME) def test_invalid_db(fixture_cv_file): @@ -212,27 +202,26 @@ def test_invalid_db(fixture_cv_file): os.remove("imagenet.mindrecord.db") with open('imagenet.mindrecord.db', 'w') as f: f.write('just for test') - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: FileReader('imagenet.mindrecord') - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ - in str(err.value) + assert 'Unexpected error. Error in execute sql:' in str(err.value) def test_overwrite_invalid_mindrecord(fixture_cv_file): """test file writer when overwrite invalid mindreocrd file.""" with open(CV_FILE_NAME, 'w') as f: f.write('just for test') - with pytest.raises(MRMOpenError) as err: + with pytest.raises(RuntimeError) as err: create_cv_mindrecord(1) - assert '[MRMOpenError]: MindRecord File could not open successfully.' \ + assert 'Unexpected error. MindRecord file already existed, please delete file:' \ in str(err.value) def test_overwrite_invalid_db(fixture_cv_file): """test file writer when overwrite invalid db file.""" with open('imagenet.mindrecord.db', 'w') as f: f.write('just for test') - with pytest.raises(MRMGenerateIndexError) as err: + with pytest.raises(RuntimeError) as err: create_cv_mindrecord(1) - assert '[MRMGenerateIndexError]: Failed to generate index.' in str(err.value) + assert 'Unexpected error. Failed to write data to db.' in str(err.value) def test_read_after_close(fixture_cv_file): """test file reader when close read.""" @@ -302,7 +291,7 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file): with pytest.raises(ParamValueError): reader.read_at_page_by_name("822", 0, "qwer") - with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): + with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"): reader.read_at_page_by_id(99999, 0, 1) @@ -320,10 +309,10 @@ def test_mindpage_filename_not_exist(fixture_cv_file): info = reader.read_category_info() logger.info("category info: {}".format(info)) - with pytest.raises(MRMFetchDataError): + with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"): reader.read_at_page_by_id(9999, 0, 1) - with pytest.raises(MRMFetchDataError): + with pytest.raises(RuntimeError, match="Unexpected error. Invalid category name."): reader.read_at_page_by_name("abc.jpg", 0, 1) with pytest.raises(ParamValueError): @@ -475,7 +464,7 @@ def test_write_with_invalid_data(): mindrecord_file_name = "test.mindrecord" # field: file_name => filename - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -510,7 +499,7 @@ def test_write_with_invalid_data(): writer.commit() # field: mask => masks - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -545,7 +534,7 @@ def test_write_with_invalid_data(): writer.commit() # field: data => image - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -580,7 +569,7 @@ def test_write_with_invalid_data(): writer.commit() # field: label => labels - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -615,7 +604,7 @@ def test_write_with_invalid_data(): writer.commit() # field: score => scores - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -650,7 +639,7 @@ def test_write_with_invalid_data(): writer.commit() # string type with int value - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -685,7 +674,7 @@ def test_write_with_invalid_data(): writer.commit() # field with int64 type, but the real data is string - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -720,7 +709,7 @@ def test_write_with_invalid_data(): writer.commit() # bytes field is string - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -755,7 +744,7 @@ def test_write_with_invalid_data(): writer.commit() # field is not numpy type - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db") @@ -790,7 +779,7 @@ def test_write_with_invalid_data(): writer.commit() # not enough field - with pytest.raises(Exception, match="Failed to write dataset"): + with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): remove_one_file(mindrecord_file_name) remove_one_file(mindrecord_file_name + ".db")