| @@ -256,9 +256,7 @@ Status SaveToDisk::Save() { | |||
| auto mr_header = std::make_shared<mindrecord::ShardHeader>(); | |||
| auto mr_writer = std::make_unique<mindrecord::ShardWriter>(); | |||
| std::vector<std::string> blob_fields; | |||
| if (mindrecord::SUCCESS != mindrecord::ShardWriter::Initialize(&mr_writer, file_names)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter, please check above `ERROR` level message."); | |||
| } | |||
| RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names)); | |||
| std::unordered_map<std::string, int32_t> column_name_id_map; | |||
| for (auto el : tree_adapter_->GetColumnNameMap()) { | |||
| @@ -286,22 +284,16 @@ Status SaveToDisk::Save() { | |||
| std::vector<std::string> index_fields; | |||
| RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields)); | |||
| MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump(); | |||
| if (mindrecord::SUCCESS != | |||
| mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader."); | |||
| } | |||
| if (mindrecord::SUCCESS != mr_writer->SetShardHeader(mr_header)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to set header of ShardWriter."); | |||
| } | |||
| RETURN_IF_NOT_OK( | |||
| mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)); | |||
| RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header)); | |||
| first_loop = false; | |||
| } | |||
| // construct data | |||
| if (!row.empty()) { // write data | |||
| RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data)); | |||
| std::shared_ptr<std::vector<uint8_t>> output_bin_data; | |||
| if (mindrecord::SUCCESS != mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to merge blob data of ShardWriter."); | |||
| } | |||
| RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)); | |||
| std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data; | |||
| raw_data.insert( | |||
| std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data})); | |||
| @@ -309,18 +301,12 @@ Status SaveToDisk::Save() { | |||
| if (output_bin_data != nullptr) { | |||
| bin_data.emplace_back(*output_bin_data); | |||
| } | |||
| if (mindrecord::SUCCESS != mr_writer->WriteRawData(raw_data, bin_data)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to write raw data to ShardWriter."); | |||
| } | |||
| RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data)); | |||
| } | |||
| } while (!row.empty()); | |||
| if (mindrecord::SUCCESS != mr_writer->Commit()) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to commit ShardWriter."); | |||
| } | |||
| if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::Finalize(file_names)) { | |||
| RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator."); | |||
| } | |||
| RETURN_IF_NOT_OK(mr_writer->Commit()); | |||
| RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -23,6 +23,7 @@ | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" | |||
| #include "minddata/mindrecord/include/shard_column.h" | |||
| #include "minddata/dataset/engine/db_connector.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "minddata/dataset/include/dataset/constants.h" | |||
| @@ -63,10 +64,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str | |||
| // Private helper method to encapsulate some common construction/reset tasks | |||
| Status MindRecordOp::Init() { | |||
| auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, | |||
| num_padded_); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed, " + ErrnoToMessage(rc)); | |||
| RETURN_IF_NOT_OK(shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, | |||
| operators_, num_padded_)); | |||
| data_schema_ = std::make_unique<DataSchema>(); | |||
| @@ -206,7 +205,9 @@ Status MindRecordOp::GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, i | |||
| fetched_row->setPath(file_path); | |||
| fetched_row->setId(row_id); | |||
| } | |||
| if (tupled_buffer.empty()) return Status::OK(); | |||
| if (tupled_buffer.empty()) { | |||
| return Status::OK(); | |||
| } | |||
| if (task_type == mindrecord::TaskType::kCommonTask) { | |||
| for (const auto &tupled_row : tupled_buffer) { | |||
| std::vector<uint8_t> columns_blob = std::get<0>(tupled_row); | |||
| @@ -237,20 +238,15 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint | |||
| // Get column data | |||
| auto shard_column = shard_reader_->GetShardColumn(); | |||
| if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { | |||
| auto rc = | |||
| shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); | |||
| if (rc.first != MSRStatus::SUCCESS) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid parameter, column_name: " + column_name + "does not exist in dataset."); | |||
| } | |||
| if (rc.second == mindrecord::ColumnInRaw) { | |||
| auto column_in_raw = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); | |||
| if (column_in_raw == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve raw data from padding sample."); | |||
| } | |||
| } else if (rc.second == mindrecord::ColumnInBlob) { | |||
| if (sample_bytes_.find(column_name) == sample_bytes_.end()) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve blob data from padding sample."); | |||
| } | |||
| mindrecord::ColumnCategory category; | |||
| RETURN_IF_NOT_OK(shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, | |||
| &column_shape, &category)); | |||
| if (category == mindrecord::ColumnInRaw) { | |||
| RETURN_IF_NOT_OK(shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes)); | |||
| } else if (category == mindrecord::ColumnInBlob) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(sample_bytes_.find(column_name) != sample_bytes_.end(), | |||
| "Invalid data, failed to retrieve blob data from padding sample."); | |||
| std::string ss(sample_bytes_[column_name]); | |||
| n_bytes = ss.size(); | |||
| data_ptr = std::make_unique<unsigned char[]>(n_bytes); | |||
| @@ -262,12 +258,9 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint | |||
| data = reinterpret_cast<const unsigned char *>(data_ptr.get()); | |||
| } | |||
| } else { | |||
| auto has_column = | |||
| shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, | |||
| &column_data_type, &column_data_type_size, &column_shape); | |||
| if (has_column == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve data from mindrecord reader."); | |||
| } | |||
| RETURN_IF_NOT_OK(shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, | |||
| &n_bytes, &column_data_type, &column_data_type_size, | |||
| &column_shape)); | |||
| } | |||
| std::shared_ptr<Tensor> tensor; | |||
| @@ -309,15 +302,10 @@ Status MindRecordOp::Reset() { | |||
| } | |||
| Status MindRecordOp::LaunchThreadsAndInitOp() { | |||
| if (tree_ == nullptr) { | |||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(tree_); | |||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | |||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | |||
| if (shard_reader_->Launch(true) == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); | |||
| } | |||
| RETURN_IF_NOT_OK(shard_reader_->Launch(true)); | |||
| // Launch main workers that load TensorRows by reading all images | |||
| RETURN_IF_NOT_OK( | |||
| tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id())); | |||
| @@ -330,12 +318,7 @@ Status MindRecordOp::LaunchThreadsAndInitOp() { | |||
| /* Counts the rows of the files listed in `dataset_path` by delegating to a freshly created ShardReader; `count` is forwarded to the reader (presumably as the output slot -- confirm). */ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset, | |||
| const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded) { | |||
| std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); /* reader is local to this call and released on return */ | |||
| MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); /* NOTE(review): this MSRStatus-based call and the check below look like the pre-change form of the RETURN_IF_NOT_OK line further down; the listing has no +/- markers, so confirm which form is current. */ | |||
| if (rc == MSRStatus::FAILED) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "Invalid data, MindRecordOp failed to count total rows. Check whether there are corresponding .db files " | |||
| "and the value of dataset_file parameter is given correctly."); | |||
| } | |||
| RETURN_IF_NOT_OK(shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded)); /* propagates the reader's Status unchanged on failure */ | |||
| return Status::OK(); | |||
| } | |||
| @@ -38,9 +38,8 @@ Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std:: | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||
| RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape)); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), | |||
| std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), | |||
| @@ -57,9 +56,8 @@ Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, con | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||
| RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape)); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| std::shared_ptr<Tensor> tensor; | |||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); | |||
| @@ -81,9 +79,8 @@ Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::v | |||
| uint64_t n_bytes = 0, col_type_size = 1; | |||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | |||
| std::vector<int64_t> column_shape; | |||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); | |||
| RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||
| &col_type_size, &column_shape)); | |||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | |||
| @@ -94,10 +94,9 @@ Status GraphLoader::InitAndLoad() { | |||
| TaskGroup vg; | |||
| shard_reader_ = std::make_unique<ShardReader>(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, | |||
| "Fail to open" + mr_path_); | |||
| RETURN_IF_NOT_OK(shard_reader_->Open({mr_path_}, true, num_workers_)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); | |||
| RETURN_IF_NOT_OK(shard_reader_->Launch(true)); | |||
| graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); | |||
| mindrecord::json schema = graph_impl_->data_schema_["schema"]; | |||
| @@ -116,8 +115,7 @@ Status GraphLoader::InitAndLoad() { | |||
| if (graph_impl_->server_mode_) { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| int64_t total_blob_size = 0; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS, | |||
| "failed to get total blob size"); | |||
| RETURN_IF_NOT_OK(shard_reader_->GetTotalBlobSize(&total_blob_size)); | |||
| graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_); | |||
| RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); | |||
| #endif | |||
| @@ -1,83 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/mindrecord/include/shard_error.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| /* File-local lookup table from MSRStatus codes to the message text returned by ErrnoToMessage(). */ static const std::map<MSRStatus, std::string> kErrnoToMessage = { | |||
| {FAILED, "operator failed"}, | |||
| {SUCCESS, "operator success"}, | |||
| {OPEN_FILE_FAILED, "open file failed"}, | |||
| {CLOSE_FILE_FAILED, "close file failed"}, | |||
| {WRITE_METADATA_FAILED, "write metadata failed"}, | |||
| {WRITE_RAWDATA_FAILED, "write rawdata failed"}, | |||
| {GET_SCHEMA_FAILED, "get schema failed"}, | |||
| {ILLEGAL_RAWDATA, "illegal raw data"}, | |||
| {PYTHON_TO_JSON_FAILED, "pybind: python object to json failed"}, | |||
| {DIR_CREATE_FAILED, "directory create failed"}, | |||
| {OPEN_DIR_FAILED, "open directory failed"}, | |||
| {INVALID_STATISTICS, "invalid statistics object"}, | |||
| {OPEN_DATABASE_FAILED, "open database failed"}, | |||
| {CLOSE_DATABASE_FAILED, "close database failed"}, | |||
| {DATABASE_OPERATE_FAILED, "database operate failed"}, | |||
| {BUILD_SCHEMA_FAILED, "build schema failed"}, | |||
| {DIVISOR_IS_ILLEGAL, "divisor is illegal"}, | |||
| {INVALID_FILE_PATH, "file path is invalid"}, | |||
| {SECURE_FUNC_FAILED, "secure function failed"}, | |||
| {ALLOCATE_MEM_FAILED, "allocate memory failed"}, | |||
| {ILLEGAL_FIELD_NAME, "illegal field name"}, | |||
| {ILLEGAL_FIELD_TYPE, "illegal field type"}, | |||
| {SET_METADATA_FAILED, "set metadata failed"}, | |||
| {ILLEGAL_SCHEMA_DEFINITION, "illegal schema definition"}, | |||
| {ILLEGAL_COLUMN_LIST, "illegal column list"}, | |||
| {SQL_ERROR, "sql error"}, | |||
| {ILLEGAL_SHARD_COUNT, "illegal shard count"}, | |||
| {ILLEGAL_SCHEMA_COUNT, "illegal schema count"}, | |||
| {VERSION_ERROR, "data version is not matched"}, | |||
| {ADD_SCHEMA_FAILED, "add schema failed"}, | |||
| {ILLEGAL_Header_SIZE, "illegal header size"}, | |||
| {ILLEGAL_Page_SIZE, "illegal page size"}, | |||
| {ILLEGAL_SIZE_VALUE, "illegal size value"}, | |||
| {INDEX_FIELD_ERROR, "add index fields failed"}, | |||
| {GET_CANDIDATE_CATEGORYFIELDS_FAILED, "get candidate category fields failed"}, | |||
| {GET_CATEGORY_INFO_FAILED, "get category information failed"}, | |||
| {ILLEGAL_CATEGORY_ID, "illegal category id"}, | |||
| {ILLEGAL_ROWNUMBER_OF_PAGE, "illegal row number of page"}, | |||
| {ILLEGAL_SCHEMA_ID, "illegal schema id"}, | |||
| {DESERIALIZE_SCHEMA_FAILED, "deserialize schema failed"}, | |||
| {DESERIALIZE_STATISTICS_FAILED, "deserialize statistics failed"}, | |||
| {ILLEGAL_DB_FILE, "illegal db file"}, | |||
| {OVERWRITE_DB_FILE, "overwrite db file"}, | |||
| {OVERWRITE_MINDRECORD_FILE, "overwrite mindrecord file"}, | |||
| {ILLEGAL_MINDRECORD_FILE, "illegal mindrecord file"}, | |||
| {PARSE_JSON_FAILED, "parse json failed"}, | |||
| {ILLEGAL_PARAMETERS, "illegal parameters"}, | |||
| {GET_PAGE_BY_GROUP_ID_FAILED, "get page by group id failed"}, | |||
| {GET_SYSTEM_STATE_FAILED, "get system state failed"}, | |||
| {IO_FAILED, "io operate failed"}, | |||
| {MATCH_HEADER_FAILED, "match header failed"}}; | |||
| // Translates a mindrecord status code into its message text. | |||
| // @param status  code to look up in kErrnoToMessage. | |||
| // @return the mapped message, or "invalid error no" when the code has no entry in the table. | |||
| std::string ErrnoToMessage(MSRStatus status) { | |||
| // Reuse the iterator returned by find() rather than paying for a second lookup through at(). | |||
| const auto iter = kErrnoToMessage.find(status); | |||
| if (iter == kErrnoToMessage.end()) { | |||
| return "invalid error no"; | |||
| } | |||
| return iter->second; | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -36,20 +36,42 @@ using mindspore::MsLogLevel::ERROR; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| #define THROW_IF_ERROR(s) \ | |||
| do { \ | |||
| Status rc = std::move(s); \ | |||
| if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ | |||
| } while (false) | |||
| // Registers the Python class `Schema` on module `m`. | |||
| // `build` converts the Python schema object to json (ToJsonImpl) before calling Schema::Build; | |||
| // `get_schema_content` converts the json returned by Schema::GetSchema back to a Python object (FromJsonImpl). | |||
| // NOTE(review): this listing shows a function-pointer registration and a lambda registration for the same | |||
| // method names without +/- markers; presumably the lambda forms are the current ones -- confirm. | |||
| void BindSchema(py::module *m) { | |||
| (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local()) | |||
| .def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build) | |||
| .def_static("build", | |||
| [](const std::string &desc, const pybind11::handle &schema) { | |||
| json schema_json = nlohmann::detail::ToJsonImpl(schema); | |||
| // `desc` is a const reference: std::move could never move from it, so pass it as is. | |||
| return Schema::Build(desc, schema_json); | |||
| }) | |||
| .def("get_desc", &Schema::GetDesc) | |||
| .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) | |||
| .def("get_schema_content", | |||
| [](Schema &s) { | |||
| json schema_json = s.GetSchema(); | |||
| return nlohmann::detail::FromJsonImpl(schema_json); | |||
| }) | |||
| .def("get_blob_fields", &Schema::GetBlobFields) | |||
| .def("get_schema_id", &Schema::GetSchemaID); | |||
| } | |||
| // Registers the Python class `Statistics` on module `m`. | |||
| // `build` converts the Python statistics object to json (ToJsonImpl) before calling Statistics::Build; | |||
| // `get_statistics` converts the json returned by Statistics::GetStatistics back to a Python object. | |||
| // NOTE(review): this listing shows a function-pointer registration and a lambda registration for the same | |||
| // method names without +/- markers; presumably the lambda forms are the current ones -- confirm. | |||
| void BindStatistics(const py::module *m) { | |||
| (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local()) | |||
| .def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build) | |||
| .def_static("build", | |||
| // Take `desc` by const reference (as the Schema binding does): the former const by-value | |||
| // parameter was copied on entry and std::move on a const object copied it a second time. | |||
| [](const std::string &desc, const pybind11::handle &statistics) { | |||
| json statistics_json = nlohmann::detail::ToJsonImpl(statistics); | |||
| return Statistics::Build(desc, statistics_json); | |||
| }) | |||
| .def("get_desc", &Statistics::GetDesc) | |||
| .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) | |||
| .def("get_statistics", | |||
| [](Statistics &s) { | |||
| json statistics_json = s.GetStatistics(); | |||
| return nlohmann::detail::FromJsonImpl(statistics_json); | |||
| }) | |||
| .def("get_statistics_id", &Statistics::GetStatisticsID); | |||
| } | |||
| @@ -59,70 +81,179 @@ void BindShardHeader(const py::module *m) { | |||
| .def("add_schema", &ShardHeader::AddSchema) | |||
| .def("add_statistics", &ShardHeader::AddStatistic) | |||
| .def("add_index_fields", | |||
| (MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields) | |||
| [](ShardHeader &s, const std::vector<std::string> &fields) { | |||
| THROW_IF_ERROR(s.AddIndexFields(fields)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("get_meta", &ShardHeader::GetSchemas) | |||
| .def("get_statistics", &ShardHeader::GetStatistics) | |||
| .def("get_fields", &ShardHeader::GetFields) | |||
| .def("get_schema_by_id", &ShardHeader::GetSchemaByID) | |||
| .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); | |||
| .def("get_schema_by_id", | |||
| [](ShardHeader &s, int64_t schema_id) { | |||
| std::shared_ptr<Schema> schema_ptr; | |||
| THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr)); | |||
| return schema_ptr; | |||
| }) | |||
| .def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) { | |||
| std::shared_ptr<Statistics> statistics_ptr; | |||
| THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr)); | |||
| return statistics_ptr; | |||
| }); | |||
| } | |||
| /* Registers the Python class `ShardWriter` on module `m`; each lambda-wrapped method throws through THROW_IF_ERROR when the underlying call reports an error and returns SUCCESS otherwise. */ void BindShardWriter(py::module *m) { | |||
| (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local()) | |||
| .def(py::init<>()) | |||
| .def("open", &ShardWriter::Open) /* NOTE(review): the member-pointer registrations from here to the `commit` line ending in `;` look like the pre-change form of the lambda registrations that follow; the listing has no +/- markers -- confirm which form is current. */ | |||
| .def("open_for_append", &ShardWriter::OpenForAppend) | |||
| .def("set_header_size", &ShardWriter::SetHeaderSize) | |||
| .def("set_page_size", &ShardWriter::SetPageSize) | |||
| .def("set_shard_header", &ShardWriter::SetShardHeader) | |||
| .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, | |||
| vector<vector<uint8_t>> &, bool, bool)) & | |||
| ShardWriter::WriteRawData) | |||
| .def("commit", &ShardWriter::Commit); | |||
| .def("open", | |||
| [](ShardWriter &s, const std::vector<std::string> &paths, bool append) { | |||
| THROW_IF_ERROR(s.Open(paths, append)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("open_for_append", | |||
| [](ShardWriter &s, const std::string &path) { | |||
| THROW_IF_ERROR(s.OpenForAppend(path)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("set_header_size", | |||
| [](ShardWriter &s, const uint64_t &header_size) { | |||
| THROW_IF_ERROR(s.SetHeaderSize(header_size)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("set_page_size", | |||
| [](ShardWriter &s, const uint64_t &page_size) { | |||
| THROW_IF_ERROR(s.SetPageSize(page_size)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("set_shard_header", | |||
| [](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) { | |||
| THROW_IF_ERROR(s.SetShardHeader(header_data)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("write_raw_data", | |||
| [](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign, bool parallel_writer) { | |||
| std::map<uint64_t, std::vector<json>> raw_data_json; /* same keys as raw_data, with every py::handle converted to json */ | |||
| (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), | |||
| [](const std::pair<uint64_t, std::vector<py::handle>> &p) { | |||
| auto &py_raw_data = p.second; | |||
| std::vector<json> json_raw_data; | |||
| (void)std::transform( | |||
| py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), | |||
| [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); | |||
| return std::make_pair(p.first, std::move(json_raw_data)); | |||
| }); | |||
| THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer)); /* the writer receives the converted json map, not the Python handles */ | |||
| return SUCCESS; | |||
| }) | |||
| .def("commit", [](ShardWriter &s) { | |||
| THROW_IF_ERROR(s.Commit()); | |||
| return SUCCESS; | |||
| }); | |||
| } | |||
| // Registers the Python class `ShardReader` on module `m`. | |||
| // Lambda-wrapped methods throw through THROW_IF_ERROR when the underlying call reports an error and | |||
| // return SUCCESS otherwise. `launch` always calls ShardReader::Launch(false). `get_next` turns every | |||
| // (blob, json) tuple from ShardReader::GetNext into (uncompressed blob columns, Python object). | |||
| // NOTE(review): this listing shows function-pointer registrations and lambda registrations for the same | |||
| // method names without +/- markers; presumably the lambda forms are the current ones -- confirm. | |||
| void BindShardReader(const py::module *m) { | |||
| (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local()) | |||
| .def(py::init<>()) | |||
| .def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &, | |||
| const std::vector<std::string> &, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &)) & | |||
| ShardReader::OpenPy) | |||
| .def("launch", &ShardReader::Launch) | |||
| .def("open", | |||
| [](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer, | |||
| const std::vector<std::string> &selected_columns, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators) { | |||
| THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("launch", | |||
| [](ShardReader &s) { | |||
| THROW_IF_ERROR(s.Launch(false)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("get_header", &ShardReader::GetShardHeader) | |||
| .def("get_blob_fields", &ShardReader::GetBlobFields) | |||
| .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) & | |||
| ShardReader::GetNextPy) | |||
| .def("get_next", | |||
| [](ShardReader &s) { | |||
| auto data = s.GetNext(); | |||
| vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> res; | |||
| // Discard the returned output iterator explicitly, as the other bindings in this file do. | |||
| (void)std::transform(data.begin(), data.end(), std::back_inserter(res), | |||
| [&s](const std::tuple<std::vector<uint8_t>, json> &item) { | |||
| auto &j = std::get<1>(item); | |||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||
| auto blob_data_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| // Surface a failed blob decode as an exception instead of silently handing | |||
| // empty blob data to Python. | |||
| THROW_IF_ERROR(s.UnCompressBlob(std::get<0>(item), &blob_data_ptr)); | |||
| return std::make_tuple(*blob_data_ptr, std::move(obj)); | |||
| }); | |||
| return res; | |||
| }) | |||
| .def("close", &ShardReader::Close); | |||
| } | |||
| /* Registers the Python class `ShardIndexGenerator` (constructed from a string and a bool) on module `m`; the lambda-wrapped `build` and `write_to_db` throw through THROW_IF_ERROR on error and return SUCCESS otherwise. */ void BindShardIndexGenerator(const py::module *m) { | |||
| (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local()) | |||
| .def(py::init<const std::string &, bool>()) | |||
| .def("build", &ShardIndexGenerator::Build) /* NOTE(review): this line and the next (ending in `;`) look like the pre-change form of the lambda registrations below; the listing has no +/- markers -- confirm which form is current. */ | |||
| .def("write_to_db", &ShardIndexGenerator::WriteToDatabase); | |||
| .def("build", | |||
| [](ShardIndexGenerator &s) { | |||
| THROW_IF_ERROR(s.Build()); | |||
| return SUCCESS; | |||
| }) | |||
| .def("write_to_db", [](ShardIndexGenerator &s) { | |||
| THROW_IF_ERROR(s.WriteToDatabase()); | |||
| return SUCCESS; | |||
| }); | |||
| } | |||
| /* Registers the Python class `ShardSegment` on module `m`; lambda-wrapped methods throw through THROW_IF_ERROR on error, and the page readers convert each json element to a Python object before returning. */ void BindShardSegment(py::module *m) { | |||
| (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local()) | |||
| .def(py::init<>()) | |||
| .def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &, | |||
| const std::vector<std::string> &, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &)) & | |||
| ShardSegment::OpenPy) /* NOTE(review): function-pointer registrations such as this one appear next to lambda registrations for the same method names; the listing has no +/- markers -- confirm which form is current. */ | |||
| .def("open", | |||
| [](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer, | |||
| const std::vector<std::string> &selected_columns, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators) { | |||
| THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("get_category_fields", | |||
| (std::pair<MSRStatus, vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetCategoryFields) | |||
| .def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField) | |||
| .def("read_category_info", (std::pair<MSRStatus, std::string>(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo) | |||
| .def("read_at_page_by_id", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>( | |||
| ShardSegment::*)(int64_t, int64_t, int64_t)) & | |||
| ShardSegment::ReadAtPageByIdPy) | |||
| .def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>( | |||
| ShardSegment::*)(std::string, int64_t, int64_t)) & | |||
| ShardSegment::ReadAtPageByNamePy) | |||
| [](ShardSegment &s) { | |||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||
| THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr)); | |||
| return *fields_ptr; /* returns a copy of the field list */ | |||
| }) | |||
| .def("set_category_field", | |||
| [](ShardSegment &s, const std::string &category_field) { | |||
| THROW_IF_ERROR(s.SetCategoryField(category_field)); | |||
| return SUCCESS; | |||
| }) | |||
| .def("read_category_info", | |||
| [](ShardSegment &s) { | |||
| std::shared_ptr<std::string> category_ptr; /* NOTE(review): starts out null and is dereferenced below; assumes ReadCategoryInfo always allocates it when it succeeds -- confirm. */ | |||
| THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr)); | |||
| return *category_ptr; | |||
| }) | |||
| .def("read_at_page_by_id", | |||
| [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { | |||
| auto pages_load_ptr = std::make_shared<PAGES_LOAD>(); /* receives the converted (blob, Python object) tuples */ | |||
| auto pages_ptr = std::make_shared<PAGES>(); | |||
| THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr)); | |||
| (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr), | |||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||
| auto &j = std::get<1>(item); | |||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||
| }); | |||
| return *pages_load_ptr; | |||
| }) | |||
| .def("read_at_page_by_name", | |||
| [](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) { | |||
| auto pages_load_ptr = std::make_shared<PAGES_LOAD>(); /* receives the converted (blob, Python object) tuples */ | |||
| auto pages_ptr = std::make_shared<PAGES>(); | |||
| THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr)); | |||
| (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr), | |||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||
| auto &j = std::get<1>(item); | |||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||
| }); | |||
| return *pages_load_ptr; | |||
| }) | |||
| .def("get_header", &ShardSegment::GetShardHeader) | |||
| .def("get_blob_fields", | |||
| (std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetBlobFields); | |||
| .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); }); | |||
| } | |||
| void BindGlobalParams(py::module *m) { | |||
| @@ -57,26 +57,24 @@ bool ValidateFieldName(const std::string &str) { | |||
| return true; | |||
| } | |||
| std::pair<MSRStatus, std::string> GetFileName(const std::string &path) { | |||
| Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(fn_ptr); | |||
| char real_path[PATH_MAX] = {0}; | |||
| char buf[PATH_MAX] = {0}; | |||
| if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { | |||
| MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; | |||
| return {FAILED, ""}; | |||
| RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path); | |||
| } | |||
| char tmp[PATH_MAX] = {0}; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||
| return {FAILED, ""}; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf)); | |||
| } | |||
| if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { | |||
| MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; | |||
| } | |||
| #else | |||
| if (realpath(dirname(&(buf[0])), tmp) == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||
| return {FAILED, ""}; | |||
| RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf); | |||
| } | |||
| if (realpath(common::SafeCStr(path), real_path) == nullptr) { | |||
| MS_LOG(DEBUG) << "Path: " << path << "check successfully"; | |||
| @@ -87,32 +85,32 @@ std::pair<MSRStatus, std::string> GetFileName(const std::string &path) { | |||
| size_t i = s.rfind(sep, s.length()); | |||
| if (i != std::string::npos) { | |||
| if (i + 1 < s.size()) { | |||
| return {SUCCESS, s.substr(i + 1)}; | |||
| *fn_ptr = std::make_shared<std::string>(s.substr(i + 1)); | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| return {SUCCESS, s}; | |||
| *fn_ptr = std::make_shared<std::string>(s); | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) { | |||
| Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(pd_ptr); | |||
| char real_path[PATH_MAX] = {0}; | |||
| char buf[PATH_MAX] = {0}; | |||
| if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { | |||
| MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; | |||
| return {FAILED, ""}; | |||
| RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path); | |||
| } | |||
| char tmp[PATH_MAX] = {0}; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||
| return {FAILED, ""}; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf)); | |||
| } | |||
| if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { | |||
| MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; | |||
| } | |||
| #else | |||
| if (realpath(dirname(&(buf[0])), tmp) == nullptr) { | |||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||
| return {FAILED, ""}; | |||
| RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf); | |||
| } | |||
| if (realpath(common::SafeCStr(path), real_path) == nullptr) { | |||
| MS_LOG(DEBUG) << "Path: " << path << "check successfully"; | |||
| @@ -120,9 +118,11 @@ std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) { | |||
| #endif | |||
| std::string s = real_path; | |||
| if (s.rfind('/') + 1 <= s.size()) { | |||
| return {SUCCESS, s.substr(0, s.rfind('/') + 1)}; | |||
| *pd_ptr = std::make_shared<std::string>(s.substr(0, s.rfind('/') + 1)); | |||
| return Status::OK(); | |||
| } | |||
| return {SUCCESS, "/"}; | |||
| *pd_ptr = std::make_shared<std::string>("/"); | |||
| return Status::OK(); | |||
| } | |||
| bool CheckIsValidUtf8(const std::string &str) { | |||
| @@ -163,15 +163,16 @@ bool IsLegalFile(const std::string &path) { | |||
| return false; | |||
| } | |||
| std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) { | |||
| Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(size_ptr); | |||
| #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) | |||
| return {SUCCESS, 100}; | |||
| *size_ptr = std::make_shared<uint64_t>(100); | |||
| return Status::OK(); | |||
| #else | |||
| uint64_t ll_count = 0; | |||
| struct statfs disk_info; | |||
| if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { | |||
| MS_LOG(ERROR) << "Get disk size error"; | |||
| return {FAILED, 0}; | |||
| RETURN_STATUS_UNEXPECTED("Get disk size error."); | |||
| } | |||
| switch (disk_type) { | |||
| @@ -187,8 +188,8 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis | |||
| ll_count = 0; | |||
| break; | |||
| } | |||
| return {SUCCESS, ll_count}; | |||
| *size_ptr = std::make_shared<uint64_t>(ll_count); | |||
| return Status::OK(); | |||
| #endif | |||
| } | |||
| @@ -201,17 +202,15 @@ uint32_t GetMaxThreadNum() { | |||
| return thread_num; | |||
| } | |||
| /// \brief Build the absolute path of every mindrecord file listed in `addresses`. | |||
| /// Each entry of `addresses` is appended to the parent directory of `path` (as resolved by GetParentDir). | |||
| /// \param[in] path path to one of the mindrecord files | |||
| /// \param[in] addresses json array holding the relative path of every mindrecord file | |||
| /// \param[out] ds shared ptr of the vector the absolute paths are appended to; allocated here when it is empty | |||
| /// \return Status error if `ds` is null or the parent directory cannot be resolved, OK otherwise | |||
| Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds) { | |||
|   RETURN_UNEXPECTED_IF_NULL(ds); | |||
|   std::shared_ptr<std::string> parent_dir; | |||
|   RETURN_IF_NOT_OK(GetParentDir(path, &parent_dir)); | |||
|   // The sibling helpers (GetFileName, GetParentDir, GetDiskSize) allocate their output themselves. | |||
|   // Do the same here so that a caller passing an empty shared_ptr does not trigger a null dereference. | |||
|   // A vector supplied by the caller is kept and appended to, as before. | |||
|   if (*ds == nullptr) { | |||
|     *ds = std::make_shared<std::vector<std::string>>(); | |||
|   } | |||
|   for (const auto &p : addresses) { | |||
|     (*ds)->emplace_back(*parent_dir + std::string(p)); | |||
|   } | |||
|   return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -33,6 +33,7 @@ | |||
| #include <future> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <random> | |||
| #include <set> | |||
| #include <sstream> | |||
| @@ -159,13 +160,15 @@ bool ValidateFieldName(const std::string &str); | |||
| /// \brief get the filename by the path | |||
| /// \param[in] path file path | |||
| /// \return | |||
| std::pair<MSRStatus, std::string> GetFileName(const std::string &s); | |||
| /// \param[out] fn_ptr shared ptr that receives the file name | |||
| /// \return Status | |||
| Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr); | |||
| /// \brief get parent dir | |||
| /// \param path file path | |||
| /// \return parent path | |||
| std::pair<MSRStatus, std::string> GetParentDir(const std::string &path); | |||
| /// \param[out] pd_ptr shared ptr that receives the parent path | |||
| /// \return Status | |||
| Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr); | |||
| bool CheckIsValidUtf8(const std::string &str); | |||
| @@ -179,8 +182,9 @@ enum DiskSizeType { kTotalSize = 0, kFreeSize }; | |||
| /// \brief get the free space about the disk | |||
| /// \param str_dir file path | |||
| /// \param disk_type: kTotalSize / kFreeSize | |||
| /// \return size in Megabytes | |||
| std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type); | |||
| /// \param[out] size shared ptr that receives the size in Megabytes | |||
| /// \return Status | |||
| Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size); | |||
| /// \brief get the max hardware concurrency | |||
| /// \return max concurrency | |||
| @@ -189,8 +193,9 @@ uint32_t GetMaxThreadNum(); | |||
| /// \brief get absolute path of all mindrecord files | |||
| /// \param path path to one of the mindrecord files | |||
| /// \param addresses relative path of all mindrecord files | |||
| /// \return vector of absolute path | |||
| std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses); | |||
| /// \param[out] ds shared ptr to the vector that the absolute paths are appended to | |||
| /// \return Status | |||
| Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds); | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato | |||
| bool GetReplacement() const { return replacement_; } | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| Status Execute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -65,11 +65,11 @@ class __attribute__((visibility("default"))) ShardColumn { | |||
| ~ShardColumn() = default; | |||
| /// \brief get column value by column name | |||
| MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const json &columns_json, const unsigned char **data, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape); | |||
| Status GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const json &columns_json, const unsigned char **data, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape); | |||
| /// \brief compress blob | |||
| std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size); | |||
| @@ -90,19 +90,18 @@ class __attribute__((visibility("default"))) ShardColumn { | |||
| std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } | |||
| /// \brief get column value from blob | |||
| MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| uint64_t *const n_bytes); | |||
| Status GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| uint64_t *const n_bytes); | |||
| /// \brief get column type | |||
| std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name, | |||
| ColumnDataType *column_data_type, | |||
| uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape); | |||
| Status GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type, | |||
| uint64_t *column_data_type_size, std::vector<int64_t> *column_shape, | |||
| ColumnCategory *column_category); | |||
| /// \brief get column value from json | |||
| MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | |||
| Status GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | |||
| private: | |||
| /// \brief initialization | |||
| @@ -110,15 +109,15 @@ class __attribute__((visibility("default"))) ShardColumn { | |||
| /// \brief get float value from json | |||
| template <typename T> | |||
| MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | |||
| Status GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | |||
| /// \brief get integer value from json | |||
| template <typename T> | |||
| MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value); | |||
| Status GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value); | |||
| /// \brief get column offset address and size from blob | |||
| MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||
| uint64_t *num_bytes, uint64_t *shift_idx); | |||
| Status GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||
| uint64_t *num_bytes, uint64_t *shift_idx); | |||
| /// \brief check if column name is available | |||
| ColumnCategory CheckColumnName(const std::string &column_name); | |||
| @@ -128,8 +127,8 @@ class __attribute__((visibility("default"))) ShardColumn { | |||
| /// \brief uncompress integer array column | |||
| template <typename T> | |||
| static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); | |||
| static Status UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); | |||
| /// \brief convert big-endian bytes to unsigned int | |||
| /// \param bytes_array bytes array | |||
| @@ -39,7 +39,7 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha | |||
| ~ShardDistributedSample() override{}; | |||
| MSRStatus PreExecute(ShardTaskList &tasks) override; | |||
| Status PreExecute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -19,65 +19,55 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include "include/api/status.h" | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| /// Status-propagation helpers. Each macro body is wrapped in do { ... } while (false) so that it | |||
| /// expands to a single statement. All of them `return` from the enclosing function, so they can | |||
| /// only be used inside functions whose return type accepts a Status. | |||
| /// | |||
| /// \brief Evaluate `_s` once; if the resulting Status reports an error, return it from the enclosing function. | |||
| #define RETURN_IF_NOT_OK(_s) \ | |||
| do { \ | |||
| Status __rc = (_s); \ | |||
| if (__rc.IsError()) { \ | |||
| return __rc; \ | |||
| } \ | |||
| } while (false) | |||
| /// \brief Same as RETURN_IF_NOT_OK, but releases resources before returning the error: | |||
| /// `_db` is passed to sqlite3_close when it is not null, and close() is called on `_in`. | |||
| /// Nothing is released when `_s` succeeds. | |||
| #define RELEASE_AND_RETURN_IF_NOT_OK(_s, _db, _in) \ | |||
| do { \ | |||
| Status __rc = (_s); \ | |||
| if (__rc.IsError()) { \ | |||
| if ((_db) != nullptr) { \ | |||
| sqlite3_close(_db); \ | |||
| } \ | |||
| (_in).close(); \ | |||
| return __rc; \ | |||
| } \ | |||
| } while (false) | |||
| /// \brief If `_condition` is false, return a kMDUnexpectedError Status built from the current | |||
| /// line, file and the message `_e`. | |||
| #define CHECK_FAIL_RETURN_UNEXPECTED(_condition, _e) \ | |||
| do { \ | |||
| if (!(_condition)) { \ | |||
| return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \ | |||
| } \ | |||
| } while (false) | |||
| /// \brief If `_ptr` is null, return an unexpected-error Status whose message contains the | |||
| /// spelling of the pointer expression (stringified with #_ptr). | |||
| #define RETURN_UNEXPECTED_IF_NULL(_ptr) \ | |||
| do { \ | |||
| if ((_ptr) == nullptr) { \ | |||
| std::string err_msg = "The pointer[" + std::string(#_ptr) + "] is null."; \ | |||
| RETURN_STATUS_UNEXPECTED(err_msg); \ | |||
| } \ | |||
| } while (false) | |||
| /// \brief Unconditionally return a kMDUnexpectedError Status built from the current line, file | |||
| /// and the message `_e`. | |||
| #define RETURN_STATUS_UNEXPECTED(_e) \ | |||
| do { \ | |||
| return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \ | |||
| } while (false) | |||
| enum MSRStatus { | |||
| SUCCESS = 0, | |||
| FAILED = 1, | |||
| OPEN_FILE_FAILED, | |||
| CLOSE_FILE_FAILED, | |||
| WRITE_METADATA_FAILED, | |||
| WRITE_RAWDATA_FAILED, | |||
| GET_SCHEMA_FAILED, | |||
| ILLEGAL_RAWDATA, | |||
| PYTHON_TO_JSON_FAILED, | |||
| DIR_CREATE_FAILED, | |||
| OPEN_DIR_FAILED, | |||
| INVALID_STATISTICS, | |||
| OPEN_DATABASE_FAILED, | |||
| CLOSE_DATABASE_FAILED, | |||
| DATABASE_OPERATE_FAILED, | |||
| BUILD_SCHEMA_FAILED, | |||
| DIVISOR_IS_ILLEGAL, | |||
| INVALID_FILE_PATH, | |||
| SECURE_FUNC_FAILED, | |||
| ALLOCATE_MEM_FAILED, | |||
| ILLEGAL_FIELD_NAME, | |||
| ILLEGAL_FIELD_TYPE, | |||
| SET_METADATA_FAILED, | |||
| ILLEGAL_SCHEMA_DEFINITION, | |||
| ILLEGAL_COLUMN_LIST, | |||
| SQL_ERROR, | |||
| ILLEGAL_SHARD_COUNT, | |||
| ILLEGAL_SCHEMA_COUNT, | |||
| VERSION_ERROR, | |||
| ADD_SCHEMA_FAILED, | |||
| ILLEGAL_Header_SIZE, | |||
| ILLEGAL_Page_SIZE, | |||
| ILLEGAL_SIZE_VALUE, | |||
| INDEX_FIELD_ERROR, | |||
| GET_CANDIDATE_CATEGORYFIELDS_FAILED, | |||
| GET_CATEGORY_INFO_FAILED, | |||
| ILLEGAL_CATEGORY_ID, | |||
| ILLEGAL_ROWNUMBER_OF_PAGE, | |||
| ILLEGAL_SCHEMA_ID, | |||
| DESERIALIZE_SCHEMA_FAILED, | |||
| DESERIALIZE_STATISTICS_FAILED, | |||
| ILLEGAL_DB_FILE, | |||
| OVERWRITE_DB_FILE, | |||
| OVERWRITE_MINDRECORD_FILE, | |||
| ILLEGAL_MINDRECORD_FILE, | |||
| PARSE_JSON_FAILED, | |||
| ILLEGAL_PARAMETERS, | |||
| GET_PAGE_BY_GROUP_ID_FAILED, | |||
| GET_SYSTEM_STATE_FAILED, | |||
| IO_FAILED, | |||
| MATCH_HEADER_FAILED | |||
| }; | |||
| // convert error no to string message | |||
| std::string __attribute__((visibility("default"))) ErrnoToMessage(MSRStatus status); | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -37,9 +37,9 @@ class __attribute__((visibility("default"))) ShardHeader { | |||
| ~ShardHeader() = default; | |||
| MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true); | |||
| Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true); | |||
| static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path); | |||
| static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr); | |||
| /// \brief add the schema and save it | |||
| /// \param[in] schema the schema needs to be added | |||
| /// \return the last schema's id | |||
| @@ -53,9 +53,9 @@ class __attribute__((visibility("default"))) ShardHeader { | |||
| /// \brief create index and add fields which from schema for each schema | |||
| /// \param[in] fields the index fields needs to be added | |||
| /// \return Status: OK if the fields were added, an error otherwise | |||
| MSRStatus AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields); | |||
| Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields); | |||
| MSRStatus AddIndexFields(const std::vector<std::string> &fields); | |||
| Status AddIndexFields(const std::vector<std::string> &fields); | |||
| /// \brief get the schema | |||
| /// \return the schema | |||
| @@ -79,9 +79,10 @@ class __attribute__((visibility("default"))) ShardHeader { | |||
| std::shared_ptr<Index> GetIndex(); | |||
| /// \brief get the schema by schemaid | |||
| /// \param[in] schemaId the id of schema needs to be got | |||
| /// \return the schema obtained by schemaId | |||
| std::pair<std::shared_ptr<Schema>, MSRStatus> GetSchemaByID(int64_t schema_id); | |||
| /// \param[in] schema_id the id of the schema to get | |||
| /// \param[out] schema_ptr the schema obtained by schema_id | |||
| /// \return Status | |||
| Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr); | |||
| /// \brief get the filepath to shard by shardID | |||
| /// \param[in] shardID the id of shard which filepath needs to be obtained | |||
| @@ -89,25 +90,26 @@ class __attribute__((visibility("default"))) ShardHeader { | |||
| std::string GetShardAddressByID(int64_t shard_id); | |||
| /// \brief get the statistic by statistic id | |||
| /// \param[in] statisticId the id of statistic needs to be get | |||
| /// \return the statistics obtained by statistic id | |||
| std::pair<std::shared_ptr<Statistics>, MSRStatus> GetStatisticByID(int64_t statistic_id); | |||
| /// \param[in] statistic_id the id of the statistic to get | |||
| /// \param[out] statistics_ptr the statistics obtained by statistic_id | |||
| /// \return Status | |||
| Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr); | |||
| MSRStatus InitByFiles(const std::vector<std::string> &file_paths); | |||
| Status InitByFiles(const std::vector<std::string> &file_paths); | |||
| void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } | |||
| std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id); | |||
| Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr); | |||
| MSRStatus SetPage(const std::shared_ptr<Page> &new_page); | |||
| Status SetPage(const std::shared_ptr<Page> &new_page); | |||
| MSRStatus AddPage(const std::shared_ptr<Page> &new_page); | |||
| Status AddPage(const std::shared_ptr<Page> &new_page); | |||
| int64_t GetLastPageId(const int &shard_id); | |||
| int GetLastPageIdByType(const int &shard_id, const std::string &page_type); | |||
| const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id); | |||
| Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr); | |||
| std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } | |||
| @@ -129,43 +131,41 @@ class __attribute__((visibility("default"))) ShardHeader { | |||
| std::vector<std::string> SerializeHeader(); | |||
| MSRStatus PagesToFile(const std::string dump_file_name); | |||
| Status PagesToFile(const std::string dump_file_name); | |||
| MSRStatus FileToPages(const std::string dump_file_name); | |||
| Status FileToPages(const std::string dump_file_name); | |||
| static MSRStatus Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||
| uint64_t &schema_id); | |||
| static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||
| uint64_t &schema_id); | |||
| private: | |||
| MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset); | |||
| Status InitializeHeader(const std::vector<json> &headers, bool load_dataset); | |||
| /// \brief get the headers from all the shard data | |||
| /// \param[in] real_addresses the real paths of the shard data | |||
| /// \param[out] headers the headers read from the shard data | |||
| /// \return Status | |||
| MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); | |||
| Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); | |||
| MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); | |||
| Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); | |||
| /// \brief check the binary file status | |||
| static MSRStatus CheckFileStatus(const std::string &path); | |||
| static Status CheckFileStatus(const std::string &path); | |||
| static std::pair<MSRStatus, json> ValidateHeader(const std::string &path); | |||
| void ParseHeader(const json &header); | |||
| static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr); | |||
| void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses); | |||
| MSRStatus ParseIndexFields(const json &index_fields); | |||
| Status ParseIndexFields(const json &index_fields); | |||
| MSRStatus CheckIndexField(const std::string &field, const json &schema); | |||
| Status CheckIndexField(const std::string &field, const json &schema); | |||
| MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset); | |||
| Status ParsePage(const json &page, int shard_index, bool load_dataset); | |||
| MSRStatus ParseStatistics(const json &statistics); | |||
| Status ParseStatistics(const json &statistics); | |||
| MSRStatus ParseSchema(const json &schema); | |||
| Status ParseSchema(const json &schema); | |||
| void ParseShardAddress(const json &address); | |||
| @@ -181,7 +181,7 @@ class __attribute__((visibility("default"))) ShardHeader { | |||
| std::shared_ptr<Index> InitIndexPtr(); | |||
| MSRStatus GetAllSchemaID(std::set<uint64_t> &bucket_count); | |||
| Status GetAllSchemaID(std::set<uint64_t> &bucket_count); | |||
| uint32_t shard_count_; | |||
| uint64_t header_size_; | |||
| @@ -30,23 +30,24 @@ | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| using INDEX_FIELDS = std::pair<MSRStatus, std::vector<std::tuple<std::string, std::string, std::string>>>; | |||
| using ROW_DATA = std::pair<MSRStatus, std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>>; | |||
| using INDEX_FIELDS = std::vector<std::tuple<std::string, std::string, std::string>>; | |||
| using ROW_DATA = std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>; | |||
| class __attribute__((visibility("default"))) ShardIndexGenerator { | |||
| public: | |||
| explicit ShardIndexGenerator(const std::string &file_path, bool append = false); | |||
| MSRStatus Build(); | |||
| Status Build(); | |||
| static std::pair<MSRStatus, std::string> GenerateFieldName(const std::pair<uint64_t, std::string> &field); | |||
| static Status GenerateFieldName(const std::pair<uint64_t, std::string> &field, std::shared_ptr<std::string> *fn_ptr); | |||
| ~ShardIndexGenerator() {} | |||
| /// \brief fetch value in json by field name | |||
| /// \param[in] field | |||
| /// \param[in] input | |||
| /// \return pair<MSRStatus, value> | |||
| std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input); | |||
| /// \param[out] value shared ptr that receives the fetched value | |||
| /// \return Status | |||
| Status GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value); | |||
| /// \brief fetch field type in schema n by field path | |||
| /// \param[in] field_path | |||
| @@ -55,55 +56,54 @@ class __attribute__((visibility("default"))) ShardIndexGenerator { | |||
| static std::string TakeFieldType(const std::string &field_path, json schema); | |||
| /// \brief create databases for indexes | |||
| MSRStatus WriteToDatabase(); | |||
| Status WriteToDatabase(); | |||
| static MSRStatus Finalize(const std::vector<std::string> file_names); | |||
| static Status Finalize(const std::vector<std::string> file_names); | |||
| private: | |||
| static int Callback(void *not_used, int argc, char **argv, char **az_col_name); | |||
| static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); | |||
| static Status ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); | |||
| static std::string ConvertJsonToSQL(const std::string &json); | |||
| std::pair<MSRStatus, sqlite3 *> CreateDatabase(int shard_no); | |||
| Status CreateDatabase(int shard_no, sqlite3 **db); | |||
| std::pair<MSRStatus, std::vector<json>> GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in); | |||
| Status GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in, | |||
| std::shared_ptr<std::vector<json>> *detail_ptr); | |||
| static std::pair<MSRStatus, std::string> GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields); | |||
| static Status GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields, | |||
| std::shared_ptr<std::string> *sql_ptr); | |||
| std::pair<MSRStatus, sqlite3 *> CheckDatabase(const std::string &shard_address); | |||
| Status CheckDatabase(const std::string &shard_address, sqlite3 **db); | |||
| /// | |||
| /// \param shard_no | |||
| /// \param blob_id_to_page_id | |||
| /// \param raw_page_id | |||
| /// \param in | |||
| /// \return field name, db type, field value | |||
| ROW_DATA GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, | |||
| std::fstream &in); | |||
| /// \return Status | |||
| Status GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, std::fstream &in, | |||
| std::shared_ptr<ROW_DATA> *row_data_ptr); | |||
| /// | |||
| /// \param db | |||
| /// \param sql | |||
| /// \param data | |||
| /// \return Status | |||
| MSRStatus BindParameterExecuteSQL( | |||
| sqlite3 *db, const std::string &sql, | |||
| const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data); | |||
| Status BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data); | |||
| INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail); | |||
| Status GenerateIndexFields(const std::vector<json> &schema_detail, std::shared_ptr<INDEX_FIELDS> *index_fields_ptr); | |||
| MSRStatus ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db, | |||
| const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id); | |||
| Status ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids, | |||
| const std::map<int, int> &blob_id_to_page_id); | |||
| MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); | |||
| Status CreateShardNameTable(sqlite3 *db, const std::string &shard_name); | |||
| MSRStatus AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||
| const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, | |||
| std::fstream &in); | |||
| Status AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||
| const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, std::fstream &in); | |||
| void AddIndexFieldByRawData(const std::vector<json> &schema_detail, | |||
| std::vector<std::tuple<std::string, std::string, std::string>> &row_data); | |||
| Status AddIndexFieldByRawData(const std::vector<json> &schema_detail, | |||
| std::vector<std::tuple<std::string, std::string, std::string>> &row_data); | |||
| void DatabaseWriter(); // worker thread | |||
| @@ -13,7 +13,6 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | |||
| @@ -28,33 +27,29 @@ class __attribute__((visibility("default"))) ShardOperator { | |||
| public: | |||
| virtual ~ShardOperator() = default; | |||
| /// \brief Run the operator on `tasks`: PreExecute, then Execute, then SufExecute, in that order. | |||
| /// The first stage that fails stops the sequence; later stages are not called. | |||
| /// The MSRStatus form reports any failure as FAILED; the Status form returns the failing | |||
| /// stage's Status unchanged. | |||
| /// \param[in,out] tasks task list handed to each stage | |||
| MSRStatus operator()(ShardTaskList &tasks) { | |||
| if (SUCCESS != this->PreExecute(tasks)) { | |||
| return FAILED; | |||
| } | |||
| if (SUCCESS != this->Execute(tasks)) { | |||
| return FAILED; | |||
| } | |||
| if (SUCCESS != this->SufExecute(tasks)) { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| Status operator()(ShardTaskList &tasks) { | |||
| RETURN_IF_NOT_OK(this->PreExecute(tasks)); | |||
| RETURN_IF_NOT_OK(this->Execute(tasks)); | |||
| RETURN_IF_NOT_OK(this->SufExecute(tasks)); | |||
| return Status::OK(); | |||
| } | |||
| virtual bool HasChildOp() { return child_op_ != nullptr; } | |||
| /// \brief Set the child operator. | |||
| /// A null `child_op` is ignored: the current child is kept and the call still reports success. | |||
| /// \param[in] child_op operator to store as child_op_ | |||
| /// \return always success (SUCCESS / Status::OK()) | |||
| virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) { | |||
| if (child_op != nullptr) child_op_ = child_op; | |||
| return SUCCESS; | |||
| virtual Status SetChildOp(std::shared_ptr<ShardOperator> child_op) { | |||
| if (child_op != nullptr) { | |||
| child_op_ = child_op; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } | |||
| /// \brief Hook called by operator() before Execute. The default does nothing and reports success. | |||
| virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; } | |||
| virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); } | |||
| /// \brief Main stage called by operator(); pure virtual, so every concrete operator must implement it. | |||
| virtual MSRStatus Execute(ShardTaskList &tasks) = 0; | |||
| virtual Status Execute(ShardTaskList &tasks) = 0; | |||
| /// \brief Hook called by operator() after Execute. The default does nothing and reports success. | |||
| virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; } | |||
| virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); } | |||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } | |||
| @@ -72,9 +67,9 @@ class __attribute__((visibility("default"))) ShardOperator { | |||
| std::shared_ptr<ShardOperator> child_op_ = nullptr; | |||
| // maps shard_id to inc_count, the running (cumulative) sample count up to and including that shard | |||
| // 0 : 15 - shard0 has 15 samples | |||
| // 1 : 41 - shard1 has 26 samples | |||
| // 2 : 58 - shard2 has 17 samples | |||
| std::vector<uint32_t> shard_sample_count_; | |||
| dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal; | |||
| @@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor | |||
| ~ShardPkSample() override{}; | |||
| MSRStatus SufExecute(ShardTaskList &tasks) override; | |||
| Status SufExecute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples() const { return num_samples_; } | |||
| @@ -59,12 +59,9 @@ | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| using ROW_GROUPS = | |||
| std::tuple<MSRStatus, std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>; | |||
| using ROW_GROUP_BRIEF = | |||
| std::tuple<MSRStatus, std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>; | |||
| using TASK_RETURN_CONTENT = | |||
| std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>; | |||
| using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>; | |||
| using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>; | |||
| using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>; | |||
| const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode | |||
| class API_PUBLIC ShardReader { | |||
| @@ -82,21 +79,10 @@ class API_PUBLIC ShardReader { | |||
| /// \param[in] num_padded the number of padded samples | |||
| /// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization | |||
| /// \return Status | |||
| MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4, | |||
| const std::vector<std::string> &selected_columns = {}, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0, | |||
| bool lazy_load = false); | |||
| /// \brief open files and initialize reader, python API | |||
| /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list | |||
| /// \param[in] load_dataset load dataset from single file or not | |||
| /// \param[in] n_consumer number of threads when reading | |||
| /// \param[in] selected_columns column list to be populated | |||
| /// \param[in] operators operators applied to data, operator type is shuffle, sample or category | |||
| /// \return MSRStatus the status of MSRStatus | |||
| MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4, | |||
| const std::vector<std::string> &selected_columns = {}, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators = {}); | |||
| Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4, | |||
| const std::vector<std::string> &selected_columns = {}, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0, | |||
| bool lazy_load = false); | |||
| /// \brief close reader | |||
| /// \return null | |||
| @@ -104,16 +90,16 @@ class API_PUBLIC ShardReader { | |||
| /// \brief read the file, get schema meta, statistics and index, single-thread mode | |||
| /// \return Status | |||
| MSRStatus Open(); | |||
| Status Open(); | |||
| /// \brief read the file, get schema meta, statistics and index, multiple-thread mode | |||
| /// \return Status | |||
| MSRStatus Open(int n_consumer); | |||
| Status Open(int n_consumer); | |||
| /// \brief launch threads to get batches | |||
| /// \param[in] is_simple_reader trigger threads if false; do nothing if true | |||
| /// \return Status | |||
| MSRStatus Launch(bool is_simple_reader = false); | |||
| Status Launch(bool is_simple_reader = false); | |||
| /// \brief aim to get the meta data | |||
| /// \return the metadata | |||
| @@ -133,8 +119,8 @@ class API_PUBLIC ShardReader { | |||
| /// \param[in] op smart pointer refer to ShardCategory or ShardSample object | |||
| /// \param[out] count # of rows | |||
| /// \return Status | |||
| MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset, | |||
| const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded); | |||
| Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset, | |||
| const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded); | |||
| /// \brief shuffle task with incremental seed | |||
| /// \return void | |||
| @@ -162,8 +148,8 @@ class API_PUBLIC ShardReader { | |||
| /// 3. Offset address of row group in file | |||
| /// 4. The list of image offset in page [startOffset, endOffset) | |||
| /// 5. The list of columns data | |||
| ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id, | |||
| const std::vector<std::string> &columns = std::vector<std::string>()); | |||
| Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns, | |||
| std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr); | |||
| /// \brief Read 1 row group data, excluding images, following an index field criteria | |||
| /// \param[in] groupID row group ID | |||
| @@ -176,8 +162,9 @@ class API_PUBLIC ShardReader { | |||
| /// 3. Offset address of row group in file | |||
| /// 4. The list of image offset in page [startOffset, endOffset) | |||
| /// 5. The list of columns data | |||
| ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, | |||
| const std::vector<std::string> &columns = std::vector<std::string>()); | |||
| Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, | |||
| const std::vector<std::string> &columns, | |||
| std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr); | |||
| /// \brief return a batch, given that one is ready | |||
| /// \return a batch of images and image data | |||
| @@ -185,13 +172,7 @@ class API_PUBLIC ShardReader { | |||
| /// \brief return a row by id | |||
| /// \return a batch of images and image data | |||
| std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>> GetNextById(const int64_t &task_id, | |||
| const int32_t &consumer_id); | |||
| /// \brief return a batch, given that one is ready, python API | |||
| /// \return a batch of images and image data | |||
| std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy(); | |||
| TASK_CONTENT GetNextById(const int64_t &task_id, const int32_t &consumer_id); | |||
| /// \brief get blob field list | |||
| /// \return blob field list | |||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | |||
| @@ -205,83 +186,86 @@ class API_PUBLIC ShardReader { | |||
| void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } | |||
| /// \brief get all classes | |||
| MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr); | |||
| /// \brief get the size of blob data | |||
| MSRStatus GetTotalBlobSize(int64_t *total_blob_size); | |||
| Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr); | |||
| /// \brief get a read-only ptr to the sampled ids for this epoch | |||
| const std::vector<int> *GetSampleIds(); | |||
| /// \brief get the size of blob data | |||
| Status GetTotalBlobSize(int64_t *total_blob_size); | |||
| /// \brief extract uncompressed data based on column list | |||
| Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data, | |||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr); | |||
| protected: | |||
| /// \brief sqlite call back function | |||
| static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | |||
| private: | |||
| /// \brief wrap up labels to json format | |||
| MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs, | |||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, | |||
| int shard_id, const std::vector<std::string> &columns, | |||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||
| Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs, | |||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id, | |||
| const std::vector<std::string> &columns, | |||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||
| /// \brief read all rows for specified columns | |||
| ROW_GROUPS ReadAllRowGroup(const std::vector<std::string> &columns); | |||
| Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr); | |||
| /// \brief read row meta by shard_id and sample_id | |||
| ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id, | |||
| const uint32_t &sample_id); | |||
| Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id, | |||
| const uint32_t &sample_id, std::shared_ptr<ROW_GROUPS> *row_group_ptr); | |||
| /// \brief read all rows in one shard | |||
| MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns, | |||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, | |||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||
| Status ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns, | |||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, | |||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||
| /// \brief initialize reader | |||
| MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset); | |||
| Status Init(const std::vector<std::string> &file_paths, bool load_dataset); | |||
| /// \brief validate column list | |||
| MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns); | |||
| Status CheckColumnList(const std::vector<std::string> &selected_columns); | |||
| /// \brief populate one row by task list in row-reader mode | |||
| MSRStatus ConsumerByRow(int consumer_id); | |||
| void ConsumerByRow(int consumer_id); | |||
| /// \brief get offset address of images within page | |||
| std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id, | |||
| const std::pair<std::string, std::string> &criteria = {"", ""}); | |||
| /// \brief get page id by category | |||
| std::pair<MSRStatus, std::vector<uint64_t>> GetPagesByCategory(int shard_id, | |||
| const std::pair<std::string, std::string> &criteria); | |||
| Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria, | |||
| std::shared_ptr<std::vector<uint64_t>> *pages_ptr); | |||
| /// \brief execute sqlite query with prepare statement | |||
| MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, | |||
| std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr); | |||
| Status QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, | |||
| std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr); | |||
| /// \brief verify the validity of dataset | |||
| MSRStatus VerifyDataset(sqlite3 **db, const string &file); | |||
| Status VerifyDataset(sqlite3 **db, const string &file); | |||
| /// \brief get column values | |||
| std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns, | |||
| const std::pair<std::string, std::string> &criteria = {"", ""}); | |||
| Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns, | |||
| const std::pair<std::string, std::string> &criteria, std::shared_ptr<std::vector<json>> *labels_ptr); | |||
| /// \brief get column values from raw data page | |||
| std::pair<MSRStatus, std::vector<json>> GetLabelsFromPage(int group_id, int shard_id, | |||
| const std::vector<std::string> &columns, | |||
| const std::pair<std::string, std::string> &criteria = {"", | |||
| ""}); | |||
| Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns, | |||
| const std::pair<std::string, std::string> &criteria, | |||
| std::shared_ptr<std::vector<json>> *labels_ptr); | |||
| /// \brief create category-applied task list | |||
| MSRStatus CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op); | |||
| Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op); | |||
| /// \brief create task list in row-reader mode | |||
| MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||
| Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||
| /// \brief create task list in row-reader mode and lazy mode | |||
| MSRStatus CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||
| Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||
| /// \brief create task list | |||
| MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||
| Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||
| /// \brief check if all specified columns are in index table | |||
| void CheckIfColumnInIndex(const std::vector<std::string> &columns); | |||
| @@ -290,11 +274,12 @@ class API_PUBLIC ShardReader { | |||
| void FileStreamsOperator(); | |||
| /// \brief read one row by one task | |||
| TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); | |||
| Status ConsumerOneTask(int task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt); | |||
| /// \brief get labels from binary file | |||
| std::pair<MSRStatus, std::vector<json>> GetLabelsFromBinaryFile( | |||
| int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets); | |||
| Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns, | |||
| const std::vector<std::vector<std::string>> &label_offsets, | |||
| std::shared_ptr<std::vector<json>> *labels_ptr); | |||
| /// \brief get classes in one shard | |||
| void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, | |||
| @@ -304,11 +289,8 @@ class API_PUBLIC ShardReader { | |||
| int64_t GetNumClasses(const std::string &category_field); | |||
| /// \brief get meta of header | |||
| std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, | |||
| std::shared_ptr<json> meta_data_ptr); | |||
| /// \brief extract uncompressed data based on column list | |||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data); | |||
| Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr, | |||
| std::shared_ptr<std::vector<std::string>> *addresses_ptr); | |||
| protected: | |||
| uint64_t header_size_; // header size | |||
| @@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator | |||
| ~ShardSample() override{}; | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| Status Execute(ShardTaskList &tasks) override; | |||
| MSRStatus UpdateTasks(ShardTaskList &tasks, int taking); | |||
| Status UpdateTasks(ShardTaskList &tasks, int taking); | |||
| MSRStatus SufExecute(ShardTaskList &tasks) override; | |||
| Status SufExecute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Schema { | |||
| /// \param[in] schema the schema's json | |||
| static std::shared_ptr<Schema> Build(std::string desc, const json &schema); | |||
| /// \brief obtain the json schema and its description for python | |||
| /// \param[in] desc the description of the schema | |||
| /// \param[in] schema the schema's json | |||
| static std::shared_ptr<Schema> Build(std::string desc, pybind11::handle schema); | |||
| /// \brief compare two schema to judge if they are equal | |||
| /// \param b another schema to be judged | |||
| /// \return true if they are equal,false if not | |||
| @@ -57,10 +52,6 @@ class __attribute__((visibility("default"))) Schema { | |||
| /// \return the json format of the schema and its description | |||
| json GetSchema() const; | |||
| /// \brief get the schema and its description for python method | |||
| /// \return the python object of the schema and its description | |||
| pybind11::object GetSchemaForPython() const; | |||
| /// set the schema id | |||
| /// \param[in] id the id need to be set | |||
| void SetSchemaID(int64_t id); | |||
| @@ -17,6 +17,7 @@ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include <utility> | |||
| @@ -25,6 +26,10 @@ | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| using CATEGORY_INFO = std::vector<std::tuple<int, std::string, int>>; | |||
| using PAGES = std::vector<std::tuple<std::vector<uint8_t>, json>>; | |||
| using PAGES_LOAD = std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>; | |||
| class __attribute__((visibility("default"))) ShardSegment : public ShardReader { | |||
| public: | |||
| ShardSegment(); | |||
| @@ -33,12 +38,12 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader { | |||
| /// \brief Get candidate category fields | |||
| /// \return a list of fields names which are the candidates of category | |||
| std::pair<MSRStatus, vector<std::string>> GetCategoryFields(); | |||
| Status GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr); | |||
| /// \brief Set category field | |||
| /// \param[in] category_field category name | |||
| /// \return true if category name is existed | |||
| MSRStatus SetCategoryField(std::string category_field); | |||
| Status SetCategoryField(std::string category_field); | |||
| /// \brief Thread-safe implementation of ReadCategoryInfo | |||
| /// \return statistics data in json format with 2 field: "key" and "categories". | |||
| @@ -50,47 +55,41 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader { | |||
| /// { "key": "label", | |||
| /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, | |||
| /// { "count": 3, "id": 1, "name": "finance", } ] } | |||
| std::pair<MSRStatus, std::string> ReadCategoryInfo(); | |||
| Status ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr); | |||
| /// \brief Thread-safe implementation of ReadAtPageById | |||
| /// \param[in] category_id category ID | |||
| /// \param[in] page_no page number | |||
| /// \param[in] n_rows_of_page rows number in one page | |||
| /// \return images array, image is a vector of uint8_t | |||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageById(int64_t category_id, int64_t page_no, | |||
| int64_t n_rows_of_page); | |||
| Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr); | |||
| /// \brief Thread-safe implementation of ReadAtPageByName | |||
| /// \param[in] category_name category Name | |||
| /// \param[in] page_no page number | |||
| /// \param[in] n_rows_of_page rows number in one page | |||
| /// \return images array, image is a vector of uint8_t | |||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageByName(std::string category_name, int64_t page_no, | |||
| int64_t n_rows_of_page); | |||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageById(int64_t category_id, | |||
| int64_t page_no, | |||
| int64_t n_rows_of_page); | |||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageByName( | |||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page); | |||
| Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr); | |||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByIdPy( | |||
| int64_t category_id, int64_t page_no, int64_t n_rows_of_page); | |||
| Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<PAGES> *pages_ptr); | |||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy( | |||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page); | |||
| Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<PAGES> *pages_ptr); | |||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | |||
| private: | |||
| std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo(); | |||
| Status WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr); | |||
| std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec); | |||
| std::string CleanUp(std::string fieldName); | |||
| std::pair<MSRStatus, std::vector<uint8_t>> PackImages(int group_id, int shard_id, std::vector<uint64_t> offset); | |||
| Status PackImages(int group_id, int shard_id, std::vector<uint64_t> offset, | |||
| std::shared_ptr<std::vector<uint8_t>> *images_ptr); | |||
| std::vector<std::string> candidate_category_fields_; | |||
| std::string current_category_field_; | |||
| @@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar | |||
| ~ShardSequentialSample() override{}; | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| Status Execute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -31,19 +31,19 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator | |||
| ~ShardShuffle() override{}; | |||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||
| Status Execute(ShardTaskList &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| private: | |||
| // Private helper function | |||
| MSRStatus CategoryShuffle(ShardTaskList &tasks); | |||
| Status CategoryShuffle(ShardTaskList &tasks); | |||
| // Keep the file sequence the same but shuffle the data within each file | |||
| MSRStatus ShuffleInfile(ShardTaskList &tasks); | |||
| Status ShuffleInfile(ShardTaskList &tasks); | |||
| // Shuffle the file sequence but keep the order of data within each file | |||
| MSRStatus ShuffleFiles(ShardTaskList &tasks); | |||
| Status ShuffleFiles(ShardTaskList &tasks); | |||
| uint32_t shuffle_seed_; | |||
| int64_t no_of_samples_; | |||
| @@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Statistics { | |||
| /// \param[in] statistics the statistic needs to be saved | |||
| static std::shared_ptr<Statistics> Build(std::string desc, const json &statistics); | |||
| /// \brief save the statistic from python and its description | |||
| /// \param[in] desc the statistic's description | |||
| /// \param[in] statistics the statistic needs to be saved | |||
| static std::shared_ptr<Statistics> Build(std::string desc, pybind11::handle statistics); | |||
| ~Statistics() = default; | |||
| /// \brief compare two statistics to judge if they are equal | |||
| @@ -59,10 +54,6 @@ class __attribute__((visibility("default"))) Statistics { | |||
| /// \return json format of the statistic | |||
| json GetStatistics() const; | |||
| /// \brief get the statistic for python | |||
| /// \return the python object of statistics | |||
| pybind11::object GetStatisticsForPython() const; | |||
| /// \brief decode the bson statistics to json | |||
| /// \param[in] encodedStatistics the bson type of statistics | |||
| /// \return json type of statistic | |||
| @@ -55,69 +55,60 @@ class __attribute__((visibility("default"))) ShardWriter { | |||
| /// \brief Open file at the beginning | |||
| /// \param[in] paths the file names list | |||
| /// \param[in] append new data at the end of file if true, otherwise overwrite file | |||
| /// \return MSRStatus the status of MSRStatus | |||
| MSRStatus Open(const std::vector<std::string> &paths, bool append = false); | |||
| /// \return Status | |||
| Status Open(const std::vector<std::string> &paths, bool append = false); | |||
| /// \brief Open file at the ending | |||
| /// \param[in] path the file name | |||
| /// \return Status | |||
| MSRStatus OpenForAppend(const std::string &path); | |||
| Status OpenForAppend(const std::string &path); | |||
| /// \brief Write header to disk | |||
| /// \return Status | |||
| MSRStatus Commit(); | |||
| Status Commit(); | |||
| /// \brief Set file size | |||
| /// \param[in] header_size the size of header, only (1<<N) is accepted | |||
| /// \return Status | |||
| MSRStatus SetHeaderSize(const uint64_t &header_size); | |||
| Status SetHeaderSize(const uint64_t &header_size); | |||
| /// \brief Set page size | |||
| /// \param[in] page_size the size of page, only (1<<N) is accepted | |||
| /// \return Status | |||
| MSRStatus SetPageSize(const uint64_t &page_size); | |||
| Status SetPageSize(const uint64_t &page_size); | |||
| /// \brief Set shard header | |||
| /// \param[in] header_data the info of header | |||
| /// WARNING, only called when file is empty | |||
| /// \return Status | |||
| MSRStatus SetShardHeader(std::shared_ptr<ShardHeader> header_data); | |||
| Status SetShardHeader(std::shared_ptr<ShardHeader> header_data); | |||
| /// \brief write raw data by group size | |||
| /// \param[in] raw_data the vector of raw json data, vector format | |||
| /// \param[in] blob_data the vector of image data | |||
| /// \param[in] sign validate data or not | |||
| /// \return Status indicating whether the write succeeded | |||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign = true, bool parallel_writer = false); | |||
| /// \brief write raw data by group size for call from python | |||
| /// \param[in] raw_data the vector of raw json data, python-handle format | |||
| /// \param[in] blob_data the vector of image data | |||
| /// \param[in] sign validate data or not | |||
| /// \return MSRStatus the status of MSRStatus to judge if write successfully | |||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign = true, bool parallel_writer = false); | |||
| Status WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign = true, bool parallel_writer = false); | |||
| /// \brief write raw data by group size for call from python | |||
| /// \param[in] raw_data the vector of raw json data, python-handle format | |||
| /// \param[in] blob_data the vector of blob json data, python-handle format | |||
| /// \param[in] sign validate data or not | |||
| /// \return Status indicating whether the write succeeded | |||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, | |||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true, | |||
| bool parallel_writer = false); | |||
| Status WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, | |||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true, | |||
| bool parallel_writer = false); | |||
| MSRStatus MergeBlobData(const std::vector<string> &blob_fields, | |||
| const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data, | |||
| std::shared_ptr<std::vector<uint8_t>> *output); | |||
| Status MergeBlobData(const std::vector<string> &blob_fields, | |||
| const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data, | |||
| std::shared_ptr<std::vector<uint8_t>> *output); | |||
| static MSRStatus Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, | |||
| const std::vector<std::string> &file_names); | |||
| static Status Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, const std::vector<std::string> &file_names); | |||
| private: | |||
| /// \brief write shard header data to disk | |||
| MSRStatus WriteShardHeader(); | |||
| Status WriteShardHeader(); | |||
| /// \brief erase error data | |||
| void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data); | |||
| @@ -130,108 +121,107 @@ class __attribute__((visibility("default"))) ShardWriter { | |||
| std::map<int, std::string> &err_raw_data); | |||
| /// \brief validate raw data | |||
| std::tuple<MSRStatus, int, int> ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, | |||
| std::vector<std::vector<uint8_t>> &blob_data, bool sign); | |||
| Status ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data, | |||
| bool sign, std::shared_ptr<std::pair<int, int>> *count_ptr); | |||
| /// \brief fill data array in multiple thread run | |||
| void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data, | |||
| std::vector<std::vector<uint8_t>> &bin_data); | |||
| /// \brief serialized raw data | |||
| MSRStatus SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, | |||
| std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count); | |||
| Status SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &bin_data, | |||
| uint32_t row_count); | |||
| /// \brief write all data parallel | |||
| MSRStatus ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| Status ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| /// \brief write data shard by shard | |||
| MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| /// \brief break image data up into multiple row groups | |||
| MSRStatus CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page, | |||
| const std::shared_ptr<Page> &last_blob_page); | |||
| Status CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page, | |||
| const std::shared_ptr<Page> &last_blob_page); | |||
| /// \brief append partial blob data to previous page | |||
| MSRStatus AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::vector<std::pair<int, int>> &rows_in_group, | |||
| const std::shared_ptr<Page> &last_blob_page); | |||
| /// \brief write new blob data page to disk | |||
| MSRStatus NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| Status AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::vector<std::pair<int, int>> &rows_in_group, | |||
| const std::shared_ptr<Page> &last_blob_page); | |||
| /// \brief write new blob data page to disk | |||
| Status NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::vector<std::pair<int, int>> &rows_in_group, | |||
| const std::shared_ptr<Page> &last_blob_page); | |||
| /// \brief shift last row group to next raw page for new appending | |||
| MSRStatus ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| std::shared_ptr<Page> &last_raw_page); | |||
| Status ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| std::shared_ptr<Page> &last_raw_page); | |||
| /// \brief write raw data page to disk | |||
| MSRStatus WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| Status WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| /// \brief generate empty raw data page | |||
| void EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||
| Status EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||
| /// \brief append a row group at the end of raw page | |||
| MSRStatus AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| const int &chunk_id, int &last_row_groupId, std::shared_ptr<Page> last_raw_page, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| Status AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id, | |||
| int &last_row_groupId, std::shared_ptr<Page> last_raw_page, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| /// \brief write blob chunk to disk | |||
| MSRStatus FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::pair<int, int> &blob_row); | |||
| Status FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data, | |||
| const std::pair<int, int> &blob_row); | |||
| /// \brief write raw chunk to disk | |||
| MSRStatus FlushRawChunk(const std::shared_ptr<std::fstream> &out, | |||
| const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| Status FlushRawChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| const int &chunk_id, const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| /// \brief break up into tasks by shard | |||
| std::vector<std::pair<int, int>> BreakIntoShards(); | |||
| /// \brief calculate raw data size row by row | |||
| MSRStatus SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| Status SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||
| /// \brief calculate blob data size row by row | |||
| MSRStatus SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data); | |||
| Status SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data); | |||
| /// \brief populate last raw page pointer | |||
| void SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||
| Status SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||
| /// \brief populate last blob page pointer | |||
| void SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page); | |||
| Status SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page); | |||
| /// \brief check the data by schema | |||
| MSRStatus CheckData(const std::map<uint64_t, std::vector<json>> &raw_data); | |||
| Status CheckData(const std::map<uint64_t, std::vector<json>> &raw_data); | |||
| /// \brief check the data and type | |||
| MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, | |||
| std::map<int, std::string> &err_raw_data); | |||
| Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, | |||
| std::map<int, std::string> &err_raw_data); | |||
| /// \brief Lock writer and save pages info | |||
| int LockWriter(bool parallel_writer = false); | |||
| Status LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr); | |||
| /// \brief Unlock writer and save pages info | |||
| MSRStatus UnlockWriter(int fd, bool parallel_writer = false); | |||
| Status UnlockWriter(int fd, bool parallel_writer = false); | |||
| /// \brief Check raw data before writing | |||
| MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign, int *schema_count, int *row_count); | |||
| Status WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||
| bool sign, int *schema_count, int *row_count); | |||
| /// \brief Get full path from file name | |||
| MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths); | |||
| Status GetFullPathFromFileName(const std::vector<std::string> &paths); | |||
| /// \brief Open files | |||
| MSRStatus OpenDataFiles(bool append); | |||
| Status OpenDataFiles(bool append); | |||
| /// \brief Remove lock file | |||
| MSRStatus RemoveLockFile(); | |||
| Status RemoveLockFile(); | |||
| /// \brief Remove lock file | |||
| MSRStatus InitLockFile(); | |||
| Status InitLockFile(); | |||
| private: | |||
| const std::string kLockFileSuffix = "_Locker"; | |||
| @@ -37,70 +37,48 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe | |||
| task_(0), | |||
| write_success_(true) {} | |||
| MSRStatus ShardIndexGenerator::Build() { | |||
| auto ret = ShardHeader::BuildSingleHeader(file_path_); | |||
| if (ret.first != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| auto json_header = ret.second; | |||
| auto ret2 = GetDatasetFiles(file_path_, json_header["shard_addresses"]); | |||
| if (SUCCESS != ret2.first) { | |||
| return FAILED; | |||
| } | |||
| Status ShardIndexGenerator::Build() { | |||
| std::shared_ptr<json> header_ptr; | |||
| RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path_, &header_ptr)); | |||
| auto ds = std::make_shared<std::vector<std::string>>(); | |||
| RETURN_IF_NOT_OK(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds)); | |||
| ShardHeader header = ShardHeader(); | |||
| auto addresses = ret2.second; | |||
| if (header.BuildDataset(addresses) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(header.BuildDataset(*ds)); | |||
| shard_header_ = header; | |||
| MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) { | |||
| if (field.empty()) { | |||
| MS_LOG(ERROR) << "The input field is None."; | |||
| return {FAILED, ""}; | |||
| } | |||
| if (input.empty()) { | |||
| MS_LOG(ERROR) << "The input json is None."; | |||
| return {FAILED, ""}; | |||
| } | |||
| Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value) { | |||
| RETURN_UNEXPECTED_IF_NULL(value); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty."); | |||
| // parameter input does not contain the field | |||
| if (input.find(field) == input.end()) { | |||
| MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input; | |||
| return {FAILED, ""}; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(), | |||
| "The field " + field + " is not found in json " + input.dump()); | |||
| // schema does not contain the field | |||
| auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; | |||
| if (schema.find(field) == schema.end()) { | |||
| MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; | |||
| return {FAILED, ""}; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), | |||
| "The field " + field + " is not found in schema " + schema.dump()); | |||
| // field should be scalar type | |||
| if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) { | |||
| MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable"; | |||
| return {FAILED, ""}; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(), | |||
| "The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable."); | |||
| if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { | |||
| auto schema_field_options = schema[field]; | |||
| if (schema_field_options.find("shape") == schema_field_options.end()) { | |||
| return {SUCCESS, input[field].dump()}; | |||
| } else { | |||
| // field with shape option | |||
| MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable"; | |||
| return {FAILED, ""}; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| schema_field_options.find("shape") == schema_field_options.end(), | |||
| "The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable."); | |||
| *value = std::make_shared<std::string>(input[field].dump()); | |||
| } else { | |||
| // the field type is string in here | |||
| *value = std::make_shared<std::string>(input[field].get<std::string>()); | |||
| } | |||
| // the field type is string in here | |||
| return {SUCCESS, input[field].get<std::string>()}; | |||
| return Status::OK(); | |||
| } | |||
| std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { | |||
| @@ -150,24 +128,28 @@ int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char ** | |||
| return 0; | |||
| } | |||
| MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { | |||
| Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { | |||
| char *z_err_msg = nullptr; | |||
| int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); | |||
| if (rc != SQLITE_OK) { | |||
| MS_LOG(ERROR) << "Sql error: " << z_err_msg; | |||
| std::ostringstream oss; | |||
| oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg; | |||
| MS_LOG(DEBUG) << oss.str(); | |||
| sqlite3_free(z_err_msg); | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||
| } else { | |||
| if (!success_msg.empty()) { | |||
| MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg; | |||
| MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg; | |||
| } | |||
| sqlite3_free(z_err_msg); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName( | |||
| const std::pair<uint64_t, std::string> &field) { | |||
| Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field, | |||
| std::shared_ptr<std::string> *fn_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(fn_ptr); | |||
| // Replaces dots and dashes with underscores for SQL use | |||
| std::string field_name = field.second; | |||
| // white list to avoid sql injection | |||
| @@ -176,95 +158,71 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName( | |||
| auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { | |||
| return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); | |||
| }); | |||
| if (pos != field_name.end()) { | |||
| MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name; | |||
| return {FAILED, ""}; | |||
| } | |||
| return {SUCCESS, field_name + "_" + std::to_string(field.first)}; | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| pos == field_name.end(), | |||
| "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name); | |||
| *fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first)); | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CheckDatabase(const std::string &shard_address) { | |||
| Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) { | |||
| auto realpath = Common::GetRealPath(shard_address); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Get real path failed, path=" << shard_address; | |||
| return {FAILED, nullptr}; | |||
| } | |||
| sqlite3 *db = nullptr; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address); | |||
| std::ifstream fin(realpath.value()); | |||
| if (!append_ && fin.good()) { | |||
| MS_LOG(ERROR) << "Invalid file, DB file already exist: " << shard_address; | |||
| fin.close(); | |||
| return {FAILED, nullptr}; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address); | |||
| } | |||
| fin.close(); | |||
| int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); | |||
| if (rc) { | |||
| MS_LOG(ERROR) << "Invalid file, failed to open database: " << shard_address << ", error" << sqlite3_errmsg(db); | |||
| return {FAILED, nullptr}; | |||
| } else { | |||
| MS_LOG(DEBUG) << "Opened database successfully"; | |||
| return {SUCCESS, db}; | |||
| if (sqlite3_open_v2(common::SafeCStr(shard_address), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" + | |||
| std::string(sqlite3_errmsg(*db))); | |||
| } | |||
| MS_LOG(DEBUG) << "Opened database successfully"; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { | |||
| Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { | |||
| // create shard_name table | |||
| std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; | |||
| if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "drop table successfully.")); | |||
| sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; | |||
| if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "create table successfully.")); | |||
| sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);"; | |||
| sqlite3_stmt *stmt = nullptr; | |||
| if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { | |||
| if (stmt != nullptr) { | |||
| (void)sqlite3_finalize(stmt); | |||
| } | |||
| MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql); | |||
| } | |||
| int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME"); | |||
| if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) { | |||
| (void)sqlite3_finalize(stmt); | |||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << shard_name; | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||
| ", field value: " + std::string(shard_name)); | |||
| } | |||
| if (sqlite3_step(stmt) != SQLITE_DONE) { | |||
| (void)sqlite3_finalize(stmt); | |||
| MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt."); | |||
| } | |||
| (void)sqlite3_finalize(stmt); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) { | |||
| Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) { | |||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | |||
| if (shard_address.empty()) { | |||
| MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; | |||
| return {FAILED, nullptr}; | |||
| } | |||
| string shard_name = GetFileName(shard_address).second; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + shard_no); | |||
| std::shared_ptr<std::string> fn_ptr; | |||
| RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr)); | |||
| shard_address += ".db"; | |||
| auto ret1 = CheckDatabase(shard_address); | |||
| if (ret1.first != SUCCESS) { | |||
| return {FAILED, nullptr}; | |||
| } | |||
| sqlite3 *db = ret1.second; | |||
| RETURN_IF_NOT_OK(CheckDatabase(shard_address, db)); | |||
| std::string sql = "DROP TABLE IF EXISTS INDEXES;"; | |||
| if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { | |||
| return {FAILED, nullptr}; | |||
| } | |||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "drop table successfully.")); | |||
| sql = | |||
| "CREATE TABLE INDEXES(" | |||
| " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" | |||
| @@ -273,95 +231,79 @@ std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no | |||
| ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; | |||
| int field_no = 0; | |||
| std::shared_ptr<std::string> field_ptr; | |||
| for (const auto &field : fields_) { | |||
| uint64_t schema_id = field.first; | |||
| auto result = shard_header_.GetSchemaByID(schema_id); | |||
| if (result.second != SUCCESS) { | |||
| return {FAILED, nullptr}; | |||
| } | |||
| json json_schema = (result.first->GetSchema())["schema"]; | |||
| std::shared_ptr<Schema> schema_ptr; | |||
| RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(schema_id, &schema_ptr)); | |||
| json json_schema = (schema_ptr->GetSchema())["schema"]; | |||
| std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); | |||
| auto ret = GenerateFieldName(field); | |||
| if (ret.first != SUCCESS) { | |||
| return {FAILED, nullptr}; | |||
| } | |||
| sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type; | |||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &field_ptr)); | |||
| sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type; | |||
| } | |||
| sql += ", PRIMARY KEY(ROW_ID"; | |||
| for (uint64_t i = 0; i < fields_.size(); ++i) { | |||
| sql += ",INC_" + std::to_string(i); | |||
| } | |||
| sql += "));"; | |||
| if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { | |||
| return {FAILED, nullptr}; | |||
| } | |||
| if (CreateShardNameTable(db, shard_name) != SUCCESS) { | |||
| return {FAILED, nullptr}; | |||
| } | |||
| return {SUCCESS, db}; | |||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "create table successfully.")); | |||
| RETURN_IF_NOT_OK(CreateShardNameTable(*db, *fn_ptr)); | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, std::vector<json>> ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, | |||
| std::fstream &in) { | |||
| std::vector<json> schema_details; | |||
| Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in, | |||
| std::shared_ptr<std::vector<json>> *detail_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(detail_ptr); | |||
| if (schema_count_ <= kMaxSchemaCount) { | |||
| for (int sc = 0; sc < schema_count_; ++sc) { | |||
| std::vector<char> schema_detail(schema_lens[sc]); | |||
| auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); | |||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | |||
| MS_LOG(ERROR) << "File read failed"; | |||
| in.close(); | |||
| return {FAILED, {}}; | |||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||
| } | |||
| schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()))); | |||
| auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end())); | |||
| (*detail_ptr)->emplace_back(j); | |||
| } | |||
| } | |||
| return {SUCCESS, schema_details}; | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL( | |||
| const std::vector<std::pair<uint64_t, std::string>> &fields) { | |||
| Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields, | |||
| std::shared_ptr<std::string> *sql_ptr) { | |||
| std::string sql = | |||
| "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," | |||
| "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; | |||
| int field_no = 0; | |||
| for (const auto &field : fields) { | |||
| auto ret = GenerateFieldName(field); | |||
| if (ret.first != SUCCESS) { | |||
| return {FAILED, ""}; | |||
| } | |||
| sql += ",INC_" + std::to_string(field_no++) + "," + ret.second; | |||
| std::shared_ptr<std::string> fn_ptr; | |||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); | |||
| sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr; | |||
| } | |||
| sql += | |||
| ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," | |||
| ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; | |||
| field_no = 0; | |||
| for (const auto &field : fields) { | |||
| auto ret = GenerateFieldName(field); | |||
| if (ret.first != SUCCESS) { | |||
| return {FAILED, ""}; | |||
| } | |||
| sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second; | |||
| std::shared_ptr<std::string> fn_ptr; | |||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); | |||
| sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr; | |||
| } | |||
| sql += " )"; | |||
| return {SUCCESS, sql}; | |||
| *sql_ptr = std::make_shared<std::string>(sql); | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( | |||
| sqlite3 *db, const std::string &sql, | |||
| const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) { | |||
| Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) { | |||
| sqlite3_stmt *stmt = nullptr; | |||
| if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { | |||
| if (stmt != nullptr) { | |||
| (void)sqlite3_finalize(stmt); | |||
| } | |||
| MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql); | |||
| } | |||
| for (auto &row : data) { | |||
| for (auto &field : row) { | |||
| @@ -373,45 +315,47 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( | |||
| if (field_type == "INTEGER") { | |||
| if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { | |||
| (void)sqlite3_finalize(stmt); | |||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index | |||
| << ", field value: " << std::stoll(field_value); | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||
| ", field value: " + std::string(field_value)); | |||
| } | |||
| } else if (field_type == "NUMERIC") { | |||
| if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { | |||
| (void)sqlite3_finalize(stmt); | |||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index | |||
| << ", field value: " << std::stold(field_value); | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||
| ", field value: " + std::string(field_value)); | |||
| } | |||
| } else if (field_type == "NULL") { | |||
| if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { | |||
| (void)sqlite3_finalize(stmt); | |||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL"; | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||
| ", field value: NULL"); | |||
| } | |||
| } else { | |||
| if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { | |||
| (void)sqlite3_finalize(stmt); | |||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value; | |||
| return FAILED; | |||
| sqlite3_close(db); | |||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||
| ", field value: " + std::string(field_value)); | |||
| } | |||
| } | |||
| } | |||
| if (sqlite3_step(stmt) != SQLITE_DONE) { | |||
| (void)sqlite3_finalize(stmt); | |||
| MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt."); | |||
| } | |||
| (void)sqlite3_reset(stmt); | |||
| } | |||
| (void)sqlite3_finalize(stmt); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||
| const std::shared_ptr<Page> cur_blob_page, | |||
| uint64_t &cur_blob_page_offset, std::fstream &in) { | |||
| Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||
| const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, | |||
| std::fstream &in) { | |||
| row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); | |||
| // blob data start | |||
| @@ -419,89 +363,71 @@ MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::strin | |||
| auto &io_seekg_blob = | |||
| in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); | |||
| if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { | |||
| MS_LOG(ERROR) << "File seekg failed"; | |||
| in.close(); | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Failed to seekg file."); | |||
| } | |||
| uint64_t image_size = 0; | |||
| auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len); | |||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | |||
| MS_LOG(ERROR) << "File read failed"; | |||
| in.close(); | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||
| } | |||
| cur_blob_page_offset += (kInt64Len + image_size); | |||
| row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| void ShardIndexGenerator::AddIndexFieldByRawData( | |||
| Status ShardIndexGenerator::AddIndexFieldByRawData( | |||
| const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) { | |||
| auto result = GenerateIndexFields(schema_detail); | |||
| if (result.first == SUCCESS) { | |||
| int index = 0; | |||
| for (const auto &field : result.second) { | |||
| // assume simple field: string , number etc. | |||
| row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); | |||
| row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); | |||
| } | |||
| } | |||
| auto index_fields_ptr = std::make_shared<INDEX_FIELDS>(); | |||
| RETURN_IF_NOT_OK(GenerateIndexFields(schema_detail, &index_fields_ptr)); | |||
| int index = 0; | |||
| for (const auto &field : *index_fields_ptr) { | |||
| // assume simple field: string , number etc. | |||
| row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); | |||
| row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, | |||
| int raw_page_id, std::fstream &in) { | |||
| std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data; | |||
| Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, | |||
| std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(row_data_ptr); | |||
| // current raw data page | |||
| auto ret1 = shard_header_.GetPage(shard_no, raw_page_id); | |||
| if (ret1.second != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get page failed"; | |||
| return {FAILED, {}}; | |||
| } | |||
| std::shared_ptr<Page> cur_raw_page = ret1.first; | |||
| std::shared_ptr<Page> page_ptr; | |||
| RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr)); | |||
| // related blob page | |||
| vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds(); | |||
| vector<pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds(); | |||
| // pair: row_group id, offset in raw data page | |||
| for (pair<int, int> blob_ids : row_group_list) { | |||
| // get blob data page according to row_group id | |||
| auto iter = blob_id_to_page_id.find(blob_ids.first); | |||
| if (iter == blob_id_to_page_id.end()) { | |||
| MS_LOG(ERROR) << "Convert blob id failed"; | |||
| return {FAILED, {}}; | |||
| } | |||
| auto ret2 = shard_header_.GetPage(shard_no, iter->second); | |||
| if (ret2.second != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get page failed"; | |||
| return {FAILED, {}}; | |||
| } | |||
| std::shared_ptr<Page> cur_blob_page = ret2.first; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id."); | |||
| std::shared_ptr<Page> blob_page_ptr; | |||
| RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr)); | |||
| // offset in current raw data page | |||
| auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second); | |||
| uint64_t cur_blob_page_offset = 0; | |||
| for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { | |||
| for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) { | |||
| std::vector<std::tuple<std::string, std::string, std::string>> row_data; | |||
| row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); | |||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); | |||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); | |||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID())); | |||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID())); | |||
| // raw data start | |||
| row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); | |||
| // calculate raw data end | |||
| auto &io_seekg = | |||
| in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||
| in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | |||
| MS_LOG(ERROR) << "File seekg failed"; | |||
| return {FAILED, {}}; | |||
| in.close(); | |||
| RETURN_STATUS_UNEXPECTED("Failed to seekg file."); | |||
| } | |||
| std::vector<uint64_t> schema_lens; | |||
| if (schema_count_ <= kMaxSchemaCount) { | |||
| for (int sc = 0; sc < schema_count_; sc++) { | |||
| @@ -509,8 +435,8 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||
| auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len); | |||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | |||
| MS_LOG(ERROR) << "File read failed"; | |||
| return {FAILED, {}}; | |||
| in.close(); | |||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||
| } | |||
| cur_raw_page_offset += (kInt64Len + schema_size); | |||
| @@ -520,122 +446,79 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||
| row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); | |||
| // Getting schema for getting data for fields | |||
| auto st_schema_detail = GetSchemaDetails(schema_lens, in); | |||
| if (st_schema_detail.first != SUCCESS) { | |||
| return {FAILED, {}}; | |||
| } | |||
| auto detail_ptr = std::make_shared<std::vector<json>>(); | |||
| RETURN_IF_NOT_OK(GetSchemaDetails(schema_lens, in, &detail_ptr)); | |||
| // start blob page info | |||
| if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) { | |||
| return {FAILED, {}}; | |||
| } | |||
| RETURN_IF_NOT_OK(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in)); | |||
| // start index field | |||
| AddIndexFieldByRawData(st_schema_detail.second, row_data); | |||
| full_data.push_back(std::move(row_data)); | |||
| AddIndexFieldByRawData(*detail_ptr, row_data); | |||
| (*row_data_ptr)->push_back(std::move(row_data)); | |||
| } | |||
| } | |||
| return {SUCCESS, full_data}; | |||
| return Status::OK(); | |||
| } | |||
// GenerateIndexFields: for each (schema id, field name) pair returned by shard_header_.GetFields(),
// reads the field's value from schema_detail[schema id], derives the SQL column type from the
// schema's "schema" entry, generates the column name, and appends a (name, type, value) tuple.
// NOTE(review): this region is a rendered diff. The pair-returning lines (`{FAILED, {}}`,
// `{SUCCESS, ...}`) appear to be the pre-change implementation and the Status-returning lines
// their replacement; both versions are shown interleaved.
| INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) { | |||
| std::vector<std::tuple<std::string, std::string, std::string>> fields; | |||
| Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail, | |||
| std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(index_fields_ptr); | |||
| // index fields | |||
| std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields(); | |||
| for (const auto &field : index_fields) { | |||
| if (field.first >= schema_detail.size()) { | |||
| return {FAILED, {}}; | |||
| } | |||
| auto field_value = GetValueByField(field.second, schema_detail[field.first]); | |||
| if (field_value.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get value from json by field name failed"; | |||
| return {FAILED, {}}; | |||
| } | |||
| auto result = shard_header_.GetSchemaByID(field.first); | |||
| if (result.second != SUCCESS) { | |||
| return {FAILED, {}}; | |||
| } | |||
| std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); | |||
| auto ret = GenerateFieldName(field); | |||
| if (ret.first != SUCCESS) { | |||
| return {FAILED, {}}; | |||
| } | |||
| fields.emplace_back(ret.second, field_type, field_value.second); | |||
| } | |||
| return {SUCCESS, std::move(fields)}; | |||
// field.first indexes schema_detail, so it is range-checked before use.
| CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range."); | |||
| std::shared_ptr<std::string> field_val_ptr; | |||
| RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr)); | |||
| std::shared_ptr<Schema> schema_ptr; | |||
| RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(field.first, &schema_ptr)); | |||
| std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"])); | |||
| std::shared_ptr<std::string> fn_ptr; | |||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); | |||
// NOTE(review): only the outer pointer is null-checked above; the shared_ptr it points to is
// dereferenced here, so callers presumably pass an already-allocated INDEX_FIELDS - confirm.
| (*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
// ExecuteTransaction: opens the shard file for shard_no in binary read mode, then, inside one
// BEGIN/END TRANSACTION pair on `db`, generates the insert SQL and the row data for every raw
// page id and executes the bound statement. Closes the file and the database before returning.
// NOTE(review): this region is a rendered diff. The MSRStatus / `db.second` lines appear to be
// the pre-change implementation and the Status / `db` lines their replacement.
// Change in this edit: the result of the final sqlite3_close() is checked again (the pre-change
// code did so) instead of being discarded, and the dead store `db = nullptr` on the by-value
// parameter is dropped.
| MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db, | |||
| const std::vector<int> &raw_page_ids, | |||
| const std::map<int, int> &blob_id_to_page_id) { | |||
| Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids, | |||
| const std::map<int, int> &blob_id_to_page_id) { | |||
| // Add index data to database | |||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | |||
| if (shard_address.empty()) { | |||
| MS_LOG(ERROR) << "Invalid data, shard address is null"; | |||
| return FAILED; | |||
| } | |||
// NOTE(review): the two early returns below leave `db` open in the visible code - confirm the
// caller does not rely on this function always closing it.
| CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty."); | |||
| auto realpath = Common::GetRealPath(shard_address); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Get real path failed, path=" << shard_address; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address); | |||
| std::fstream in; | |||
| in.open(realpath.value(), std::ios::in | std::ios::binary); | |||
| if (!in.good()) { | |||
| MS_LOG(ERROR) << "Invalid file, failed to open file: " << shard_address; | |||
| in.close(); | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address); | |||
| } | |||
| (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); | |||
| (void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); | |||
| for (int raw_page_id : raw_page_ids) { | |||
| auto sql = GenerateRawSQL(fields_); | |||
| if (sql.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Generate raw SQL failed"; | |||
| in.close(); | |||
| sqlite3_close(db.second); | |||
| return FAILED; | |||
| } | |||
| auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); | |||
| if (data.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Generate raw data failed"; | |||
| in.close(); | |||
| sqlite3_close(db.second); | |||
| return FAILED; | |||
| } | |||
| if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { | |||
| MS_LOG(ERROR) << "Execute SQL failed"; | |||
| in.close(); | |||
| sqlite3_close(db.second); | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; | |||
| } | |||
| (void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr); | |||
// RELEASE_AND_RETURN_IF_NOT_OK receives `db` and `in`; presumably it releases both before
// returning the failing Status, mirroring the removed in.close()/sqlite3_close() pairs - confirm.
| std::shared_ptr<std::string> sql_ptr; | |||
| RELEASE_AND_RETURN_IF_NOT_OK(GenerateRawSQL(fields_, &sql_ptr), db, in); | |||
| auto row_data_ptr = std::make_shared<ROW_DATA>(); | |||
| RELEASE_AND_RETURN_IF_NOT_OK(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db, in); | |||
| RELEASE_AND_RETURN_IF_NOT_OK(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in); | |||
| MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db."; | |||
| } | |||
| (void)sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); | |||
| in.close(); | |||
| // Close database | |||
| if (sqlite3_close(db.second) != SQLITE_OK) { | |||
| MS_LOG(ERROR) << "Close database failed"; | |||
| return FAILED; | |||
| } | |||
| db.second = nullptr; | |||
| return SUCCESS; | |||
// Report a failed close instead of silently ignoring it.
| int rc = sqlite3_close(db); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(rc == SQLITE_OK, "Failed to close database."); | |||
| return Status::OK(); | |||
| } | |||
// WriteToDatabase: caches fields, page size, header size and schema count from shard_header_,
// rejects headers with more than kMaxShardCount shards, resets task_ / write_success_, and
// (after the elided hunk) joins the worker threads and reports failure when write_success_ is false.
// NOTE(review): this region is a rendered diff. MSRStatus / FAILED lines appear to be the
// pre-change implementation and the Status lines their replacement.
// Change in this edit: the error message of the shard-count check now prints kMaxShardCount,
// the limit actually tested, instead of kMaxSchemaCount.
| MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||
| Status ShardIndexGenerator::WriteToDatabase() { | |||
| fields_ = shard_header_.GetFields(); | |||
| page_size_ = shard_header_.GetPageSize(); | |||
| header_size_ = shard_header_.GetHeaderSize(); | |||
| schema_count_ = shard_header_.GetSchemaCount(); | |||
| if (shard_header_.GetShardCount() > kMaxShardCount) { | |||
| MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount, | |||
| "num shards: " + std::to_string(shard_header_.GetShardCount()) + | |||
| " exceeds max count:" + std::to_string(kMaxShardCount)); | |||
| task_ = 0; // set two atomic vars to initial value | |||
| write_success_ = true; | |||
| @@ -653,40 +536,41 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||
// NOTE(review): the loop bound is threads.capacity(), not threads.size(); the construction of
// `threads` is elided here - confirm the two are guaranteed equal.
| for (size_t t = 0; t < threads.capacity(); t++) { | |||
| threads[t].join(); | |||
| } | |||
| return write_success_ ? SUCCESS : FAILED; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db."); | |||
| return Status::OK(); | |||
| } | |||
// DatabaseWriter: claims shard numbers from task_ (post-increment) until the claimed number
// reaches shard_header_.GetShardCount(). For each shard it creates the index database, walks
// pages 0..GetLastPageId(shard_no), collects RAW_DATA page ids and a BLOB_DATA
// type-id -> page-id map, and calls ExecuteTransaction. Any failure logs an error, sets
// write_success_ = false and returns.
// NOTE(review): this region is a rendered diff. Lines using `ret` / `cur_page` / `!= SUCCESS`
// appear to be the pre-change implementation and the `.IsError()` / `page_ptr` lines their replacement.
| void ShardIndexGenerator::DatabaseWriter() { | |||
| int shard_no = task_++; | |||
| while (shard_no < shard_header_.GetShardCount()) { | |||
| auto db = CreateDatabase(shard_no); | |||
| if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { | |||
// NOTE(review): the replaced condition above also stopped when write_success_ was already
// false or the handle was null; the replacement only tests CreateDatabase's status.
| sqlite3 *db = nullptr; | |||
| if (CreateDatabase(shard_no, &db).IsError()) { | |||
| MS_LOG(ERROR) << "Failed to create Generate database."; | |||
| write_success_ = false; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; | |||
| // Pre-processing page information | |||
| auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; | |||
| std::map<int, int> blob_id_to_page_id; | |||
| std::vector<int> raw_page_ids; | |||
| for (uint64_t i = 0; i < total_pages; ++i) { | |||
| auto ret = shard_header_.GetPage(shard_no, i); | |||
| if (ret.second != SUCCESS) { | |||
| std::shared_ptr<Page> page_ptr; | |||
| if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) { | |||
// NOTE(review): `db` was created above and is not closed on this return in the visible code -
// confirm whether the handle leaks.
| MS_LOG(ERROR) << "Failed to get page."; | |||
| write_success_ = false; | |||
| return; | |||
| } | |||
| std::shared_ptr<Page> cur_page = ret.first; | |||
| if (cur_page->GetPageType() == "RAW_DATA") { | |||
| if (page_ptr->GetPageType() == "RAW_DATA") { | |||
| raw_page_ids.push_back(i); | |||
| } else if (cur_page->GetPageType() == "BLOB_DATA") { | |||
| blob_id_to_page_id[cur_page->GetPageTypeID()] = i; | |||
| } else if (page_ptr->GetPageType() == "BLOB_DATA") { | |||
| blob_id_to_page_id[page_ptr->GetPageTypeID()] = i; | |||
| } | |||
| } | |||
| if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { | |||
| if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) { | |||
| MS_LOG(ERROR) << "Failed to execute transaction."; | |||
| write_success_ = false; | |||
| return; | |||
| } | |||
| @@ -694,21 +578,12 @@ void ShardIndexGenerator::DatabaseWriter() { | |||
// Claim the next shard number for this worker.
| shard_no = task_++; | |||
| } | |||
| } | |||
// Finalize: builds a ShardIndexGenerator from the first entry of file_names, then runs Build()
// followed by WriteToDatabase(). Fails when file_names is empty or either step fails.
// NOTE(review): this region is a rendered diff. The MSRStatus / `SUCCESS != ...` lines appear
// to be the pre-change implementation and the Status / RETURN_IF_NOT_OK lines their replacement.
| MSRStatus ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) { | |||
| if (file_names.empty()) { | |||
| MS_LOG(ERROR) << "Mindrecord files is empty."; | |||
| return FAILED; | |||
| } | |||
| Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty."); | |||
// Only file_names[0] is used to construct the generator.
| ShardIndexGenerator sg{file_names[0]}; | |||
| if (SUCCESS != sg.Build()) { | |||
| MS_LOG(ERROR) << "Failed to build index generator."; | |||
| return FAILED; | |||
| } | |||
| if (SUCCESS != sg.WriteToDatabase()) { | |||
| MS_LOG(ERROR) << "Failed to write to database."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| RETURN_IF_NOT_OK(sg.Build()); | |||
| RETURN_IF_NOT_OK(sg.WriteToDatabase()); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -30,9 +30,13 @@ namespace mindspore { | |||
| namespace mindrecord { | |||
// Default constructor: calls SetAllInIndex(false).
| ShardSegment::ShardSegment() { SetAllInIndex(false); } | |||
// GetCategoryFields: hands back a copy of candidate_category_fields_. When that cache is empty
// it is filled from the rows of `PRAGMA table_info(INDEXES);` run on database_paths_[0]; on a
// query failure or an out-of-range index the handle is closed, nulled, and an error is returned.
// NOTE(review): this region is a rendered diff with elided hunks. The pair-returning lines
// appear to be the pre-change implementation and the Status-returning lines their replacement.
| std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | |||
| Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(fields_ptr); | |||
| // Skip if already populated | |||
| if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_}; | |||
| if (!candidate_category_fields_.empty()) { | |||
| *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_); | |||
| return Status::OK(); | |||
| } | |||
| std::string sql = "PRAGMA table_info(INDEXES);"; | |||
| std::vector<std::vector<std::string>> field_names; | |||
| @@ -40,11 +44,12 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | |||
| char *errmsg = nullptr; | |||
| int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); | |||
| if (rc != SQLITE_OK) { | |||
| MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; | |||
// The message is built before errmsg is freed below.
| std::ostringstream oss; | |||
| oss << "Error in select statement, sql: " << sql + ", error: " << errmsg; | |||
| sqlite3_free(errmsg); | |||
| sqlite3_close(database_paths_[0]); | |||
| database_paths_[0] = nullptr; | |||
| return {FAILED, vector<std::string>{}}; | |||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||
| } else { | |||
| MS_LOG(INFO) << "Get " << static_cast<int>(field_names.size()) << " records from index."; | |||
| } | |||
| @@ -55,53 +60,46 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | |||
| sqlite3_free(errmsg); | |||
| sqlite3_close(database_paths_[0]); | |||
| database_paths_[0] = nullptr; | |||
| return {FAILED, vector<std::string>{}}; | |||
| RETURN_STATUS_UNEXPECTED("idx is out of range."); | |||
| } | |||
// Column [1] of each visited row is stored; idx advances by two rows per iteration (its initial
// value and the loop header are in the elided hunk).
| candidate_category_fields_.push_back(field_names[idx][1]); | |||
| idx += 2; | |||
| } | |||
| sqlite3_free(errmsg); | |||
| return {SUCCESS, candidate_category_fields_}; | |||
| *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_); | |||
| return Status::OK(); | |||
| } | |||
// SetCategoryField: makes sure the candidate fields are loaded, appends "_0" to category_field,
// and stores the result in current_category_field_ when it matches a candidate; otherwise fails.
// NOTE(review): this region is a rendered diff. MSRStatus / FAILED lines appear to be the
// pre-change implementation and the Status lines their replacement.
| MSRStatus ShardSegment::SetCategoryField(std::string category_field) { | |||
| if (GetCategoryFields().first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get candidate category field failed"; | |||
| return FAILED; | |||
| } | |||
| Status ShardSegment::SetCategoryField(std::string category_field) { | |||
// fields_ptr is only needed to satisfy the call; the match below reads candidate_category_fields_.
| std::shared_ptr<vector<std::string>> fields_ptr; | |||
| RETURN_IF_NOT_OK(GetCategoryFields(&fields_ptr)); | |||
| category_field = category_field + "_0"; | |||
| if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), | |||
| [category_field](std::string x) { return x == category_field; })) { | |||
| current_category_field_ = category_field; | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field."; | |||
| return FAILED; | |||
// The reported name already carries the "_0" suffix added above.
| RETURN_STATUS_UNEXPECTED("Field " + category_field + " is not a candidate category field."); | |||
| } | |||
// ReadCategoryInfo: obtains the category tuples from WrapCategoryInfo and returns them
// serialized by ToJsonForCategory as a string.
// NOTE(review): this region is a rendered diff. The pair-returning lines appear to be the
// pre-change implementation and the Status-returning lines their replacement.
| std::pair<MSRStatus, std::string> ShardSegment::ReadCategoryInfo() { | |||
| Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(category_ptr); | |||
| MS_LOG(INFO) << "Read category begin"; | |||
| auto ret = WrapCategoryInfo(); | |||
| if (ret.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get category info failed"; | |||
| return {FAILED, ""}; | |||
| } | |||
// WrapCategoryInfo fills the vector it is handed, so it is allocated first.
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||
| // Convert category info to json string | |||
| auto category_json_string = ToJsonForCategory(ret.second); | |||
| *category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr)); | |||
| MS_LOG(INFO) << "Read category end"; | |||
| return {SUCCESS, category_json_string}; | |||
| return Status::OK(); | |||
| } | |||
// WrapCategoryInfo: validates current_category_field_, runs a GROUP BY count query on it, and
// writes one (running index, value, count) tuple per entry of `counter` into the output vector.
// The definition of `db` and the code that fills `counter` are in elided hunks.
// NOTE(review): this region is a rendered diff. The pair-returning lines appear to be the
// pre-change implementation and the Status-returning lines their replacement.
| std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegment::WrapCategoryInfo() { | |||
| Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(category_info_ptr); | |||
| std::map<std::string, int> counter; | |||
| if (!ValidateFieldName(current_category_field_)) { | |||
| MS_LOG(ERROR) << "category field error from index, it is: " << current_category_field_; | |||
| return {FAILED, std::vector<std::tuple<int, std::string, int>>()}; | |||
| } | |||
// The field name is concatenated into the SQL text below, so it is validated first.
| CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_), | |||
| "Category field error from index, it is: " + current_category_field_); | |||
| std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + | |||
| ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; | |||
| @@ -109,13 +107,13 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen | |||
| std::vector<std::vector<std::string>> field_count; | |||
| char *errmsg = nullptr; | |||
| int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg); | |||
| if (rc != SQLITE_OK) { | |||
| MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; | |||
| if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) { | |||
| std::ostringstream oss; | |||
| oss << "Error in select statement, sql: " << sql + ", error: " << errmsg; | |||
| sqlite3_free(errmsg); | |||
| sqlite3_close(db); | |||
| db = nullptr; | |||
| return {FAILED, std::vector<std::tuple<int, std::string, int>>()}; | |||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||
| } else { | |||
| MS_LOG(INFO) << "Get " << static_cast<int>(field_count.size()) << " records from index."; | |||
| } | |||
| @@ -127,14 +125,14 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen | |||
| } | |||
| int idx = 0; | |||
| std::vector<std::tuple<int, std::string, int>> category_vec(counter.size()); | |||
| (void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple<std::string, int> item) { | |||
| return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); | |||
| }); | |||
| return {SUCCESS, std::move(category_vec)}; | |||
// The output vector is resized to counter.size() so std::transform can write through begin().
// NOTE(review): only the outer pointer is null-checked; callers presumably pass an allocated
// CATEGORY_INFO - confirm.
| (*category_info_ptr)->resize(counter.size()); | |||
| (void)std::transform( | |||
| counter.begin(), counter.end(), (*category_info_ptr)->begin(), | |||
| [&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); }); | |||
| return Status::OK(); | |||
| } | |||
// ToJsonForCategory: iterates the category tuples, builds json objects (construction elided by
// the hunk boundary), and returns category_info.dump() as the result string.
// NOTE(review): this region is a rendered diff; the first signature appears to be the
// pre-change spelling and the CATEGORY_INFO one its replacement.
| std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec) { | |||
| std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) { | |||
| std::vector<json> category_json_vec; | |||
// `q` is taken by value, so each tuple is copied per iteration.
| for (auto q : tri_vec) { | |||
| json j; | |||
| @@ -152,27 +150,20 @@ std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, st | |||
| return category_info.dump(); | |||
| } | |||
// ReadAtPageById: validates category_id and the page window against the category's row count,
// then walks the row-group summary, reads the offsets of rows matching the category value, and
// packs the blob bytes of rows whose running index falls in [i_start, i_end).
// NOTE(review): this region is a rendered diff with elided hunks. The pair-returning lines
// appear to be the pre-change implementation and the Status-returning lines their replacement.
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageById(int64_t category_id, | |||
| int64_t page_no, | |||
| int64_t n_rows_of_page) { | |||
| auto ret = WrapCategoryInfo(); | |||
| if (ret.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get category info"; | |||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||
| } | |||
| if (category_id >= static_cast<int>(ret.second.size()) || category_id < 0) { | |||
| MS_LOG(ERROR) << "Illegal category id, id: " << category_id; | |||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||
| } | |||
| int total_rows_in_category = std::get<2>(ret.second[category_id]); | |||
| Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(page_ptr); | |||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||
// category_id is used as a vector index below, so both bounds are checked.
| CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0, | |||
| "Invalid category id, id: " + std::to_string(category_id)); | |||
| int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]); | |||
| // Quit if category not found or page number is out of range | |||
| if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 || | |||
| page_no * n_rows_of_page >= total_rows_in_category) { | |||
| MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page; | |||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 && | |||
| page_no * n_rows_of_page < total_rows_in_category, | |||
| "Invalid page no / page size, page no: " + std::to_string(page_no) + | |||
| ", page size: " + std::to_string(n_rows_of_page)); | |||
// NOTE(review): in the Status version results are appended to *page_ptr, so this local `page`
// looks unused - confirm and remove.
| std::vector<std::vector<uint8_t>> page; | |||
| auto row_group_summary = ReadRowGroupSummary(); | |||
| uint64_t i_start = page_no * n_rows_of_page; | |||
| @@ -183,12 +174,12 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage | |||
| auto shard_id = std::get<0>(rg); | |||
| auto group_id = std::get<1>(rg); | |||
| auto details = ReadRowGroupCriteria( | |||
| group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); | |||
| if (SUCCESS != std::get<0>(details)) { | |||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||
| } | |||
| auto offsets = std::get<4>(details); | |||
// The old tuple carried a status at index 0 and offsets at index 4; the ROW_GROUP_BRIEF used
// here is read at index 3 for the offsets.
| std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr; | |||
| RETURN_IF_NOT_OK(ReadRowGroupCriteria( | |||
| group_id, shard_id, | |||
| std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""}, | |||
| &row_group_brief_ptr)); | |||
| auto offsets = std::get<3>(*row_group_brief_ptr); | |||
| uint64_t number_of_rows = offsets.size(); | |||
// Row groups that end before the requested window are skipped, advancing only the running index.
| if (idx + number_of_rows < i_start) { | |||
| idx += number_of_rows; | |||
| @@ -197,131 +188,116 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage | |||
| for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { | |||
| if (idx >= i_start && idx < i_end) { | |||
| auto ret1 = PackImages(group_id, shard_id, offsets[i]); | |||
| if (SUCCESS != ret1.first) { | |||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||
| } | |||
| page.push_back(std::move(ret1.second)); | |||
// NOTE(review): *page_ptr is dereferenced; only the outer pointer was null-checked, so callers
// presumably pass an allocated vector - confirm.
| auto images_ptr = std::make_shared<std::vector<uint8_t>>(); | |||
| RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr)); | |||
| (*page_ptr)->push_back(std::move(*images_ptr)); | |||
| } | |||
| } | |||
| } | |||
| return {SUCCESS, std::move(page)}; | |||
| return Status::OK(); | |||
| } | |||
// PackImages: looks up the page for (group_id, shard_id), seeks the shard's random-access
// stream to header_size_ + page_size_ * page id + offset[0], and reads offset[1] - offset[0]
// bytes into the output buffer. The stream is closed when the seek or the read fails.
// NOTE(review): this region is a rendered diff. The pair-returning lines appear to be the
// pre-change implementation and the Status-returning lines their replacement.
| std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id, int shard_id, | |||
| std::vector<uint64_t> offset) { | |||
| const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); | |||
| if (SUCCESS != ret.first) { | |||
| return {FAILED, std::vector<uint8_t>()}; | |||
| } | |||
| const std::shared_ptr<Page> &blob_page = ret.second; | |||
| Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset, | |||
| std::shared_ptr<std::vector<uint8_t>> *images_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(images_ptr); | |||
| std::shared_ptr<Page> page_ptr; | |||
| RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr)); | |||
| // Pack image list | |||
| std::vector<uint8_t> images(offset[1] - offset[0]); | |||
| auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; | |||
// NOTE(review): offset[0] and offset[1] are read without checking offset.size() or that
// offset[1] >= offset[0]; the subtraction is unsigned - confirm callers guarantee both.
| (*images_ptr)->resize(offset[1] - offset[0]); | |||
| auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0]; | |||
| auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); | |||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | |||
| MS_LOG(ERROR) << "File seekg failed"; | |||
| file_streams_random_[0][shard_id]->close(); | |||
| return {FAILED, {}}; | |||
| RETURN_STATUS_UNEXPECTED("Failed to seekg file."); | |||
| } | |||
| auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&images[0]), offset[1] - offset[0]); | |||
// NOTE(review): element [0] of the buffer is addressed even when the computed length is zero -
// confirm a zero-length blob cannot occur.
| auto &io_read = | |||
| file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&((*(*images_ptr))[0])), offset[1] - offset[0]); | |||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | |||
| MS_LOG(ERROR) << "File read failed"; | |||
| file_streams_random_[0][shard_id]->close(); | |||
| return {FAILED, {}}; | |||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||
| } | |||
| return {SUCCESS, std::move(images)}; | |||
| return Status::OK(); | |||
| } | |||
// ReadAtPageByName: finds the category whose name equals category_name and forwards to
// ReadAtPageById with that category's id; fails when no category matches.
// NOTE(review): this region is a rendered diff. The pair-returning lines appear to be the
// pre-change implementation and the Status-returning lines their replacement.
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageByName(std::string category_name, | |||
| int64_t page_no, | |||
| int64_t n_rows_of_page) { | |||
| auto ret = WrapCategoryInfo(); | |||
| if (ret.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get category info"; | |||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||
| } | |||
| for (const auto &categories : ret.second) { | |||
| Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(pages_ptr); | |||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||
| for (const auto &categories : *category_info_ptr) { | |||
// Tuple element 1 is compared with the name; element 0 is passed on as the category id.
| if (std::get<1>(categories) == category_name) { | |||
| auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page); | |||
| return result; | |||
| RETURN_IF_NOT_OK(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr)); | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| return {FAILED, std::vector<std::vector<uint8_t>>()}; | |||
| RETURN_STATUS_UNEXPECTED("Category name can not match."); | |||
| } | |||
// ReadAllAtPageById: validates category_id and the page window against the category's row count,
// then walks the row-group summary and, for rows whose running index falls in [i_start, i_end),
// appends (packed blob bytes, label json) tuples to the output.
// NOTE(review): this region is a rendered diff. The pair-returning lines appear to be the
// pre-change implementation and the Status-returning lines their replacement.
// Changes in this edit (Status version only):
//  - category_id is also rejected when negative, matching ReadAtPageById; it is used as a
//    vector index right after the check.
//  - the row-count error message now uses std::to_string(number_of_rows); the previous
//    `"..." + number_of_rows` was pointer arithmetic on the string literal, not concatenation.
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageById( | |||
| int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { | |||
| auto ret = WrapCategoryInfo(); | |||
| if (ret.first != SUCCESS || category_id >= static_cast<int>(ret.second.size())) { | |||
| MS_LOG(ERROR) << "Illegal category id, id: " << category_id; | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| int total_rows_in_category = std::get<2>(ret.second[category_id]); | |||
| // Quit if category not found or page number is out of range | |||
| if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) { | |||
| MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page; | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<PAGES> *pages_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(pages_ptr); | |||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(category_id >= 0 && category_id < static_cast<int64_t>(category_info_ptr->size()), | |||
| "Invalid category id: " + std::to_string(category_id)); | |||
// NOTE(review): in the Status version results are appended to *pages_ptr, so this local `page`
// looks unused - confirm and remove.
| std::vector<std::tuple<std::vector<uint8_t>, json>> page; | |||
| int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]); | |||
| // Quit if category not found or page number is out of range | |||
| CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 && | |||
| page_no * n_rows_of_page < total_rows_in_category, | |||
| "Invalid page no / page size / total size, page no: " + std::to_string(page_no) + | |||
| ", page size of page: " + std::to_string(n_rows_of_page) + | |||
| ", total size: " + std::to_string(total_rows_in_category)); | |||
| auto row_group_summary = ReadRowGroupSummary(); | |||
| int i_start = page_no * n_rows_of_page; | |||
| int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page); | |||
| int idx = 0; | |||
| for (const auto &rg : row_group_summary) { | |||
| if (idx >= i_end) break; | |||
| if (idx >= i_end) { | |||
| break; | |||
| } | |||
| auto shard_id = std::get<0>(rg); | |||
| auto group_id = std::get<1>(rg); | |||
| auto details = ReadRowGroupCriteria( | |||
| group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); | |||
| if (SUCCESS != std::get<0>(details)) { | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| auto offsets = std::get<4>(details); | |||
| auto labels = std::get<5>(details); | |||
// The old tuple carried a status at index 0 (offsets 4, labels 5); the ROW_GROUP_BRIEF used
// here is read at index 3 for offsets and 4 for labels.
| std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr; | |||
| RETURN_IF_NOT_OK(ReadRowGroupCriteria( | |||
| group_id, shard_id, | |||
| std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""}, | |||
| &row_group_brief_ptr)); | |||
| auto offsets = std::get<3>(*row_group_brief_ptr); | |||
| auto labels = std::get<4>(*row_group_brief_ptr); | |||
| int number_of_rows = offsets.size(); | |||
// Row groups that end before the requested window are skipped, advancing only the running index.
| if (idx + number_of_rows < i_start) { | |||
| idx += number_of_rows; | |||
| continue; | |||
| } | |||
| if (number_of_rows > static_cast<int>(labels.size())) { | |||
| MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows; | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
// labels[i] is read for every offsets[i], so there must be at least as many labels as rows.
| CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast<int>(labels.size()), | |||
| "Invalid row number of page: " + std::to_string(number_of_rows)); | |||
| for (int i = 0; i < number_of_rows; ++i, ++idx) { | |||
| if (idx >= i_start && idx < i_end) { | |||
| auto ret1 = PackImages(group_id, shard_id, offsets[i]); | |||
| if (SUCCESS != ret1.first) { | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| page.emplace_back(std::move(ret1.second), std::move(labels[i])); | |||
| auto images_ptr = std::make_shared<std::vector<uint8_t>>(); | |||
| RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr)); | |||
| (*pages_ptr)->emplace_back(std::move(*images_ptr), std::move(labels[i])); | |||
| } | |||
| } | |||
| } | |||
| return {SUCCESS, std::move(page)}; | |||
| return Status::OK(); | |||
| } | |||
// ReadAllAtPageByName: resolves category_name to a category id using the tuples from
// WrapCategoryInfo and forwards to ReadAllAtPageById; fails when no category matches.
// NOTE(review): this region is a rendered diff with an elided hunk. The pair-returning lines
// appear to be the pre-change implementation and the Status-returning lines their replacement.
// The two pybind11-returning functions further down (ReadAtPageByIdPy, ReadAtPageByNamePy) sit
// between the old and new tail of this function and appear to be removed by the change; the
// final CHECK_FAIL_RETURN_UNEXPECTED / return pair appears to be this function's new tail.
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageByName( | |||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page) { | |||
| auto ret = WrapCategoryInfo(); | |||
| if (ret.first != SUCCESS) { | |||
| MS_LOG(ERROR) << "Get category info"; | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||
| std::shared_ptr<PAGES> *pages_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(pages_ptr); | |||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||
| // category_name to category_id | |||
// -1 marks "no matching category" and is tested after the loop.
| int64_t category_id = -1; | |||
| for (const auto &categories : ret.second) { | |||
| for (const auto &categories : *category_info_ptr) { | |||
| std::string categories_name = std::get<1>(categories); | |||
| if (categories_name == category_name) { | |||
| @@ -329,45 +305,8 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS | |||
| break; | |||
| } | |||
| } | |||
| if (category_id == -1) { | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||
| } | |||
| return ReadAllAtPageById(category_id, page_no, n_rows_of_page); | |||
| } | |||
// ReadAtPageByIdPy: calls ReadAllAtPageById and converts each row's json label into a
// pybind11::object with nlohmann::detail::FromJsonImpl.
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( | |||
| int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { | |||
| auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page); | |||
| if (res.first != SUCCESS) { | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}}; | |||
| } | |||
| vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data; | |||
| std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), | |||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||
| auto &j = std::get<1>(item); | |||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||
| }); | |||
| return {SUCCESS, std::move(json_data)}; | |||
| } | |||
// ReadAtPageByNamePy: same conversion as ReadAtPageByIdPy, but starting from ReadAllAtPageByName.
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByNamePy( | |||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page) { | |||
| auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page); | |||
| if (res.first != SUCCESS) { | |||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}}; | |||
| } | |||
| vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data; | |||
| std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), | |||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||
| auto &j = std::get<1>(item); | |||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||
| }); | |||
| return {SUCCESS, std::move(json_data)}; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "Invalid category name."); | |||
| return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr); | |||
| } | |||
| std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { | |||
| @@ -382,7 +321,9 @@ std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { | |||
| } | |||
| std::string ShardSegment::CleanUp(std::string field_name) { | |||
| while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back(); | |||
| while (field_name.back() >= '0' && field_name.back() <= '9') { | |||
| field_name.pop_back(); | |||
| } | |||
| field_name.pop_back(); | |||
| return field_name; | |||
| } | |||
| @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem | |||
| num_categories_(num_categories), | |||
| replacement_(replacement) {} | |||
| MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; } | |||
| Status ShardCategory::Execute(ShardTaskList &tasks) { return Status::OK(); } | |||
| int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| if (dataset_size == 0) return dataset_size; | |||
| @@ -72,36 +72,36 @@ void ShardColumn::Init(const json &schema_json, bool compress_integer) { | |||
| num_blob_column_ = blob_column_.size(); | |||
| } | |||
| std::pair<MSRStatus, ColumnCategory> ShardColumn::GetColumnTypeByName(const std::string &column_name, | |||
| ColumnDataType *column_data_type, | |||
| uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape) { | |||
| Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type, | |||
| uint64_t *column_data_type_size, std::vector<int64_t> *column_shape, | |||
| ColumnCategory *column_category) { | |||
| RETURN_UNEXPECTED_IF_NULL(column_data_type); | |||
| RETURN_UNEXPECTED_IF_NULL(column_data_type_size); | |||
| RETURN_UNEXPECTED_IF_NULL(column_shape); | |||
| RETURN_UNEXPECTED_IF_NULL(column_category); | |||
| // Skip if column not found | |||
| auto column_category = CheckColumnName(column_name); | |||
| if (column_category == ColumnNotFound) { | |||
| return {FAILED, ColumnNotFound}; | |||
| } | |||
| *column_category = CheckColumnName(column_name); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid column category."); | |||
| // Get data type and size | |||
| auto column_id = column_name_id_[column_name]; | |||
| *column_data_type = column_data_type_[column_id]; | |||
| *column_data_type_size = ColumnDataTypeSize[*column_data_type]; | |||
| *column_shape = column_shape_[column_id]; | |||
| return {SUCCESS, column_category}; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const json &columns_json, const unsigned char **data, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape) { | |||
| Status ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const json &columns_json, const unsigned char **data, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||
| std::vector<int64_t> *column_shape) { | |||
| RETURN_UNEXPECTED_IF_NULL(column_data_type); | |||
| RETURN_UNEXPECTED_IF_NULL(column_data_type_size); | |||
| RETURN_UNEXPECTED_IF_NULL(column_shape); | |||
| // Skip if column not found | |||
| auto column_category = CheckColumnName(column_name); | |||
| if (column_category == ColumnNotFound) { | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid column category."); | |||
| // Get data type and size | |||
| auto column_id = column_name_id_[column_name]; | |||
| *column_data_type = column_data_type_[column_id]; | |||
| @@ -110,37 +110,31 @@ MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, cons | |||
| // Retrieve value from json | |||
| if (column_category == ColumnInRaw) { | |||
| if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { | |||
| MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes)); | |||
| *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| // Retrieve value from blob | |||
| if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { | |||
| MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes)); | |||
| if (*data == nullptr) { | |||
| *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) { | |||
| Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) { | |||
| RETURN_UNEXPECTED_IF_NULL(n_bytes); | |||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||
| auto column_id = column_name_id_[column_name]; | |||
| auto column_data_type = column_data_type_[column_id]; | |||
| // Initialize num bytes | |||
| *n_bytes = ColumnDataTypeSize[column_data_type]; | |||
| auto json_column_value = columns_json[column_name]; | |||
| if (!json_column_value.is_string() && !json_column_value.is_number()) { | |||
| MS_LOG(ERROR) << "Conversion failed (" << json_column_value << ")."; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(json_column_value.is_string() || json_column_value.is_number(), | |||
| "Conversion to string or number failed (" + json_column_value.dump() + ")."); | |||
| switch (column_data_type) { | |||
| case ColumnFloat32: { | |||
| return GetFloat<float>(data_ptr, json_column_value, false); | |||
| @@ -171,12 +165,13 @@ MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const j | |||
| break; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, | |||
| bool use_double) { | |||
| Status ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, | |||
| bool use_double) { | |||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | |||
| if (json_column_value.is_number()) { | |||
| array_data[0] = json_column_value; | |||
| @@ -189,8 +184,7 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons | |||
| array_data[0] = json_column_value.get<float>(); | |||
| } | |||
| } catch (json::exception &e) { | |||
| MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Conversion to float failed (" + json_column_value.dump() + ")."); | |||
| } | |||
| } | |||
| @@ -199,54 +193,43 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons | |||
| for (uint32_t i = 0; i < sizeof(T); i++) { | |||
| (*data_ptr)[i] = *(data + i); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) { | |||
| Status ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) { | |||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | |||
| int64_t temp_value; | |||
| bool less_than_zero = false; | |||
| if (json_column_value.is_number_integer()) { | |||
| const json json_zero = 0; | |||
| if (json_column_value < json_zero) less_than_zero = true; | |||
| if (json_column_value < json_zero) { | |||
| less_than_zero = true; | |||
| } | |||
| temp_value = json_column_value; | |||
| } else if (json_column_value.is_string()) { | |||
| std::string string_value = json_column_value; | |||
| if (!string_value.empty() && string_value[0] == '-') { | |||
| try { | |||
| try { | |||
| if (!string_value.empty() && string_value[0] == '-') { | |||
| temp_value = std::stoll(string_value); | |||
| less_than_zero = true; | |||
| } catch (std::invalid_argument &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||
| return FAILED; | |||
| } catch (std::out_of_range &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| try { | |||
| } else { | |||
| temp_value = static_cast<int64_t>(std::stoull(string_value)); | |||
| } catch (std::invalid_argument &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||
| return FAILED; | |||
| } catch (std::out_of_range &e) { | |||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||
| return FAILED; | |||
| } | |||
| } catch (std::invalid_argument &e) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what())); | |||
| } catch (std::out_of_range &e) { | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what())); | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "Conversion to int failed."; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed."); | |||
| } | |||
| if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) || | |||
| (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) { | |||
| MS_LOG(ERROR) << "Conversion to int failed. Out of range"; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); | |||
| } | |||
| array_data[0] = static_cast<T>(temp_value); | |||
| @@ -255,33 +238,26 @@ MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const | |||
| for (uint32_t i = 0; i < sizeof(T); i++) { | |||
| (*data_ptr)[i] = *(data + i); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| uint64_t *const n_bytes) { | |||
| Status ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||
| uint64_t *const n_bytes) { | |||
| RETURN_UNEXPECTED_IF_NULL(data); | |||
| uint64_t offset_address = 0; | |||
| auto column_id = column_name_id_[column_name]; | |||
| if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address)); | |||
| auto column_data_type = column_data_type_[column_id]; | |||
| if (has_compress_blob_ && column_data_type == ColumnInt32) { | |||
| if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address)); | |||
| } else if (has_compress_blob_ && column_data_type == ColumnInt64) { | |||
| if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address)); | |||
| } else { | |||
| *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address])); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | |||
| @@ -296,7 +272,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | |||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) { | |||
| // Skip if no compress columns | |||
| *compression_size = 0; | |||
| if (!CheckCompressBlob()) return blob; | |||
| if (!CheckCompressBlob()) { | |||
| return blob; | |||
| } | |||
| std::vector<uint8_t> dst_blob; | |||
| uint64_t i_src = 0; | |||
| @@ -380,12 +358,14 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const | |||
| return dst_bytes; | |||
| } | |||
| MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||
| uint64_t *num_bytes, uint64_t *shift_idx) { | |||
| Status ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||
| uint64_t *num_bytes, uint64_t *shift_idx) { | |||
| RETURN_UNEXPECTED_IF_NULL(num_bytes); | |||
| RETURN_UNEXPECTED_IF_NULL(shift_idx); | |||
| if (num_blob_column_ == 1) { | |||
| *num_bytes = columns_blob.size(); | |||
| *shift_idx = 0; | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| auto blob_id = blob_column_id_[column_name_[column_id]]; | |||
| @@ -396,13 +376,14 @@ MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const | |||
| (*shift_idx) += kInt64Len; | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| template <typename T> | |||
| MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, | |||
| uint64_t shift_idx) { | |||
| Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx) { | |||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||
| RETURN_UNEXPECTED_IF_NULL(num_bytes); | |||
| auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); | |||
| *num_bytes = sizeof(T) * num_elements; | |||
| @@ -421,19 +402,12 @@ MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr< | |||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | |||
| *data_ptr = std::make_unique<unsigned char[]>(*num_bytes); | |||
| // field is none. for example: numpy is null | |||
| if (*num_bytes == 0) { | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes); | |||
| if (ret_code != 0) { | |||
| MS_LOG(ERROR) << "Failed to copy data!"; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data!"); | |||
| return Status::OK(); | |||
| } | |||
| uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | |||
| @@ -55,15 +55,11 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_ | |||
| return 0; | |||
| } | |||
| MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { | |||
| Status ShardDistributedSample::PreExecute(ShardTaskList &tasks) { | |||
| auto total_no = tasks.Size(); | |||
| if (no_of_padded_samples_ > 0 && first_epoch_) { | |||
| if (total_no % denominator_ != 0) { | |||
| MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. " | |||
| << "task size: " << total_no << ", number padded: " << no_of_padded_samples_ | |||
| << ", denominator: " << denominator_; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(total_no % denominator_ == 0, | |||
| "Dataset size plus number of padded samples is not divisible by number of shards."); | |||
| } | |||
| if (first_epoch_) { | |||
| first_epoch_ = false; | |||
| @@ -74,11 +70,9 @@ MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { | |||
| if (shuffle_ == true) { | |||
| shuffle_op_->SetShardSampleCount(GetShardSampleCount()); | |||
| shuffle_op_->UpdateShuffleMode(GetShuffleMode()); | |||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -38,104 +38,74 @@ ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), co | |||
| index_ = std::make_shared<Index>(); | |||
| } | |||
| MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { | |||
| Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { | |||
| shard_count_ = headers.size(); | |||
| int shard_index = 0; | |||
| bool first = true; | |||
| for (const auto &header : headers) { | |||
| if (first) { | |||
| first = false; | |||
| if (ParseSchema(header["schema"]) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| if (ParseIndexFields(header["index_fields"]) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| if (ParseStatistics(header["statistics"]) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(ParseSchema(header["schema"])); | |||
| RETURN_IF_NOT_OK(ParseIndexFields(header["index_fields"])); | |||
| RETURN_IF_NOT_OK(ParseStatistics(header["statistics"])); | |||
| ParseShardAddress(header["shard_addresses"]); | |||
| header_size_ = header["header_size"].get<uint64_t>(); | |||
| page_size_ = header["page_size"].get<uint64_t>(); | |||
| compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; | |||
| } | |||
| if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(ParsePage(header["page"], shard_index, load_dataset)); | |||
| shard_index++; | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::CheckFileStatus(const std::string &path) { | |||
| Status ShardHeader::CheckFileStatus(const std::string &path) { | |||
| auto realpath = Common::GetRealPath(path); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Get real path failed, path=" << path; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path: " + path); | |||
| std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary); | |||
| if (!fin) { | |||
| MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path; | |||
| return FAILED; | |||
| } | |||
| if (fin.fail()) { | |||
| MS_LOG(ERROR) << "Failed to open file. path: " << path; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file. path: " + path); | |||
| // fetch file size | |||
| auto &io_seekg = fin.seekg(0, std::ios::end); | |||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | |||
| fin.close(); | |||
| MS_LOG(ERROR) << "File seekg failed. path: " << path; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("File seekg failed. path: " + path); | |||
| } | |||
| size_t file_size = fin.tellg(); | |||
| if (file_size < kMinFileSize) { | |||
| fin.close(); | |||
| MS_LOG(ERROR) << "Invalid file. path: " << path; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file. path: " + path); | |||
| } | |||
| fin.close(); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path) { | |||
| if (CheckFileStatus(path) != SUCCESS) { | |||
| return {FAILED, {}}; | |||
| } | |||
| Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(header_ptr); | |||
| RETURN_IF_NOT_OK(CheckFileStatus(path)); | |||
| // read header size | |||
| json json_header; | |||
| std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); | |||
| if (!fin.is_open()) { | |||
| MS_LOG(ERROR) << "File seekg failed. path: " << path; | |||
| return {FAILED, json_header}; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file. path: " + path); | |||
| uint64_t header_size = 0; | |||
| auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len); | |||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | |||
| MS_LOG(ERROR) << "File read failed"; | |||
| fin.close(); | |||
| return {FAILED, json_header}; | |||
| RETURN_STATUS_UNEXPECTED("File read failed"); | |||
| } | |||
| if (header_size > kMaxHeaderSize) { | |||
| fin.close(); | |||
| MS_LOG(ERROR) << "Invalid file content. path: " << path; | |||
| return {FAILED, json_header}; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file content. path: " + path); | |||
| } | |||
| // read header content | |||
| std::vector<uint8_t> header_content(header_size); | |||
| auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size); | |||
| if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { | |||
| MS_LOG(ERROR) << "File read failed. path: " << path; | |||
| fin.close(); | |||
| return {FAILED, json_header}; | |||
| RETURN_STATUS_UNEXPECTED("File read failed. path: " + path); | |||
| } | |||
| fin.close(); | |||
| @@ -144,34 +114,35 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path) | |||
| try { | |||
| json_header = json::parse(raw_header_content); | |||
| } catch (json::parse_error &e) { | |||
| MS_LOG(ERROR) << "Json parse error: " << e.what(); | |||
| return {FAILED, json_header}; | |||
| RETURN_STATUS_UNEXPECTED("Json parse error: " + std::string(e.what())); | |||
| } | |||
| return {SUCCESS, json_header}; | |||
| *header_ptr = std::make_shared<json>(json_header); | |||
| return Status::OK(); | |||
| } | |||
| std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) { | |||
| auto ret = ValidateHeader(file_path); | |||
| if (SUCCESS != ret.first) { | |||
| return {FAILED, json()}; | |||
| } | |||
| json raw_header = ret.second; | |||
| Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(header_ptr); | |||
| std::shared_ptr<json> raw_header; | |||
| RETURN_IF_NOT_OK(ValidateHeader(file_path, &raw_header)); | |||
| uint64_t compression_size = | |||
| raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0; | |||
| json header = {{"shard_addresses", raw_header["shard_addresses"]}, | |||
| {"header_size", raw_header["header_size"]}, | |||
| {"page_size", raw_header["page_size"]}, | |||
| raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0; | |||
| json header = {{"shard_addresses", (*raw_header)["shard_addresses"]}, | |||
| {"header_size", (*raw_header)["header_size"]}, | |||
| {"page_size", (*raw_header)["page_size"]}, | |||
| {"compression_size", compression_size}, | |||
| {"index_fields", raw_header["index_fields"]}, | |||
| {"blob_fields", raw_header["schema"][0]["blob_fields"]}, | |||
| {"schema", raw_header["schema"][0]["schema"]}, | |||
| {"version", raw_header["version"]}}; | |||
| return {SUCCESS, header}; | |||
| {"index_fields", (*raw_header)["index_fields"]}, | |||
| {"blob_fields", (*raw_header)["schema"][0]["blob_fields"]}, | |||
| {"schema", (*raw_header)["schema"][0]["schema"]}, | |||
| {"version", (*raw_header)["version"]}}; | |||
| *header_ptr = std::make_shared<json>(header); | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) { | |||
| Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) { | |||
| uint32_t thread_num = std::thread::hardware_concurrency(); | |||
| if (thread_num == 0) thread_num = kThreadNumber; | |||
| if (thread_num == 0) { | |||
| thread_num = kThreadNumber; | |||
| } | |||
| uint32_t work_thread_num = 0; | |||
| uint32_t shard_count = file_paths.size(); | |||
| int group_num = ceil(shard_count * 1.0 / thread_num); | |||
| @@ -194,12 +165,10 @@ MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, | |||
| } | |||
| if (thread_status) { | |||
| thread_status = false; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Error occurred in GetHeadersOneTask thread."); | |||
| } | |||
| if (SUCCESS != InitializeHeader(headers, load_dataset)) { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| RETURN_IF_NOT_OK(InitializeHeader(headers, load_dataset)); | |||
| return Status::OK(); | |||
| } | |||
| void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers, | |||
| @@ -208,48 +177,39 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade | |||
| return; | |||
| } | |||
| for (int x = start; x < end; ++x) { | |||
| auto ret = ValidateHeader(realAddresses[x]); | |||
| if (SUCCESS != ret.first) { | |||
| std::shared_ptr<json> header; | |||
| if (ValidateHeader(realAddresses[x], &header).IsError()) { | |||
| thread_status = true; | |||
| return; | |||
| } | |||
| json header; | |||
| header = ret.second; | |||
| header["shard_addresses"] = realAddresses; | |||
| if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { | |||
| MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() | |||
| (*header)["shard_addresses"] = realAddresses; | |||
| if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) == | |||
| kSupportedVersion.end()) { | |||
| MS_LOG(ERROR) << "Version wrong, file version is: " << (*header)["version"].dump() | |||
| << ", lib version is: " << kVersion; | |||
| thread_status = true; | |||
| return; | |||
| } | |||
| headers[x] = header; | |||
| headers[x] = *header; | |||
| } | |||
| } | |||
| MSRStatus ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) { | |||
| Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) { | |||
| std::vector<std::string> file_names(file_paths.size()); | |||
| std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { | |||
| if (GetFileName(fp).first == SUCCESS) { | |||
| return GetFileName(fp).second; | |||
| } | |||
| std::shared_ptr<std::string> fn; | |||
| return GetFileName(fp, &fn).IsOk() ? *fn : ""; | |||
| }); | |||
| shard_addresses_ = std::move(file_names); | |||
| shard_count_ = file_paths.size(); | |||
| if (shard_count_ == 0) { | |||
| return FAILED; | |||
| } | |||
| if (shard_count_ <= kMaxShardCount) { | |||
| pages_.resize(shard_count_); | |||
| } else { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount), | |||
| "shard count is invalid. shard count: " + std::to_string(shard_count_)); | |||
| pages_.resize(shard_count_); | |||
| return Status::OK(); | |||
| } | |||
| void ShardHeader::ParseHeader(const json &header) {} | |||
| MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { | |||
| Status ShardHeader::ParseIndexFields(const json &index_fields) { | |||
| std::vector<std::pair<uint64_t, std::string>> parsed_index_fields; | |||
| for (auto &index_field : index_fields) { | |||
| auto schema_id = index_field["schema_id"].get<uint64_t>(); | |||
| @@ -257,18 +217,15 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { | |||
| std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name); | |||
| parsed_index_fields.push_back(parsed_index_field); | |||
| } | |||
| if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| RETURN_IF_NOT_OK(AddIndexFields(parsed_index_fields)); | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { | |||
| Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { | |||
| // set shard_index when load_dataset is false | |||
| if (shard_count_ > kMaxFileCount) { | |||
| MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| shard_count_ <= kMaxFileCount, | |||
| "The number of mindrecord files is greater than max value: " + std::to_string(kMaxFileCount)); | |||
| if (pages_.empty() && shard_count_ <= kMaxFileCount) { | |||
| pages_.resize(shard_count_); | |||
| } | |||
| @@ -295,44 +252,37 @@ MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_d | |||
| pages_[shard_index].push_back(std::move(parsed_page)); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::ParseStatistics(const json &statistics) { | |||
| Status ShardHeader::ParseStatistics(const json &statistics) { | |||
| for (auto &statistic : statistics) { | |||
| if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) { | |||
| MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump(); | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(), | |||
| "Deserialize statistics failed, statistic: " + statistics.dump()); | |||
| std::string statistic_description = statistic["desc"].get<std::string>(); | |||
| json statistic_body = statistic["statistics"]; | |||
| std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body); | |||
| if (!parsed_statistic) { | |||
| return FAILED; | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(parsed_statistic); | |||
| AddStatistic(parsed_statistic); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::ParseSchema(const json &schemas) { | |||
| Status ShardHeader::ParseSchema(const json &schemas) { | |||
| for (auto &schema : schemas) { | |||
| // change how we get schemaBody once design is finalized | |||
| if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() || | |||
| schema.find("schema") == schema.end()) { | |||
| MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump(); | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() && | |||
| schema.find("schema") != schema.end(), | |||
| "Deserialize schema failed. schema: " + schema.dump()); | |||
| std::string schema_description = schema["desc"].get<std::string>(); | |||
| std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>(); | |||
| json schema_body = schema["schema"]; | |||
| std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body); | |||
| if (!parsed_schema) { | |||
| return FAILED; | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(parsed_schema); | |||
| AddSchema(parsed_schema); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| void ShardHeader::ParseShardAddress(const json &address) { | |||
| @@ -340,7 +290,7 @@ void ShardHeader::ParseShardAddress(const json &address) { | |||
| } | |||
| std::vector<std::string> ShardHeader::SerializeHeader() { | |||
| std::vector<string> header; | |||
| std::vector<std::string> header; | |||
| auto index = SerializeIndexFields(); | |||
| auto stats = SerializeStatistics(); | |||
| auto schema = SerializeSchema(); | |||
| @@ -406,45 +356,42 @@ std::string ShardHeader::SerializeSchema() { | |||
| std::string ShardHeader::SerializeShardAddress() { | |||
| json j; | |||
| (void)std::transform(shard_addresses_.begin(), shard_addresses_.end(), std::back_inserter(j), | |||
| [](const std::string &addr) { return GetFileName(addr).second; }); | |||
| std::shared_ptr<std::string> fn_ptr; | |||
| for (const auto &addr : shard_addresses_) { | |||
| (void)GetFileName(addr, &fn_ptr); | |||
| j.emplace_back(*fn_ptr); | |||
| } | |||
| return j.dump(); | |||
| } | |||
| std::pair<std::shared_ptr<Page>, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) { | |||
| Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(page_ptr); | |||
| if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | |||
| return std::make_pair(pages_[shard_id][page_id], SUCCESS); | |||
| } else { | |||
| return std::make_pair(nullptr, FAILED); | |||
| *page_ptr = pages_[shard_id][page_id]; | |||
| return Status::OK(); | |||
| } | |||
| page_ptr = nullptr; | |||
| RETURN_STATUS_UNEXPECTED("Failed to Get Page."); | |||
| } | |||
| MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) { | |||
| if (new_page == nullptr) { | |||
| return FAILED; | |||
| } | |||
| Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) { | |||
| int shard_id = new_page->GetShardID(); | |||
| int page_id = new_page->GetPageID(); | |||
| if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | |||
| pages_[shard_id][page_id] = new_page; | |||
| return SUCCESS; | |||
| } else { | |||
| return FAILED; | |||
| return Status::OK(); | |||
| } | |||
| RETURN_STATUS_UNEXPECTED("Failed to Set Page."); | |||
| } | |||
| MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) { | |||
| if (new_page == nullptr) { | |||
| return FAILED; | |||
| } | |||
| Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) { | |||
| int shard_id = new_page->GetShardID(); | |||
| int page_id = new_page->GetPageID(); | |||
| if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) { | |||
| pages_[shard_id].push_back(new_page); | |||
| return SUCCESS; | |||
| } else { | |||
| return FAILED; | |||
| return Status::OK(); | |||
| } | |||
| RETURN_STATUS_UNEXPECTED("Failed to Add Page."); | |||
| } | |||
| int64_t ShardHeader::GetLastPageId(const int &shard_id) { | |||
| @@ -468,20 +415,18 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag | |||
| return last_page_id; | |||
| } | |||
| const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(const int &group_id, | |||
| const int &shard_id) { | |||
| if (shard_id >= static_cast<int>(pages_.size())) { | |||
| MS_LOG(ERROR) << "Shard id is more than sum of shards."; | |||
| return {FAILED, nullptr}; | |||
| } | |||
| Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(page_ptr); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_id < static_cast<int>(pages_.size()), "Shard id is more than sum of shards."); | |||
| for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | |||
| auto page = pages_[shard_id][i - 1]; | |||
| if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { | |||
| return {SUCCESS, page}; | |||
| *page_ptr = std::make_shared<Page>(*page); | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| MS_LOG(ERROR) << "Could not get page by group id " << group_id; | |||
| return {FAILED, nullptr}; | |||
| page_ptr = nullptr; | |||
| RETURN_STATUS_UNEXPECTED("Failed to get page by group id: " + group_id); | |||
| } | |||
| int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { | |||
| @@ -524,151 +469,88 @@ std::shared_ptr<Index> ShardHeader::InitIndexPtr() { | |||
| return index; | |||
| } | |||
| MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) { | |||
| Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) { | |||
| // check field name is or is not valid | |||
| if (schema.find(field) == schema.end()) { | |||
| MS_LOG(ERROR) << "Schema do not contain the field: " << field << "."; | |||
| return FAILED; | |||
| } | |||
| if (schema[field]["type"] == "bytes") { | |||
| MS_LOG(ERROR) << field << " is bytes type, can not be schema index field."; | |||
| return FAILED; | |||
| } | |||
| if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) { | |||
| MS_LOG(ERROR) << field << " array can not be schema index field."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), "Filed can not found in schema."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes", "bytes can not be as index field."); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(), | |||
| "array can not be as index field."); | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||
| Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||
| if (fields.empty()) { | |||
| return Status::OK(); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Schema is empty."); | |||
| // create index Object | |||
| std::shared_ptr<Index> index = InitIndexPtr(); | |||
| if (fields.size() == kInt0) { | |||
| MS_LOG(ERROR) << "There are no index fields"; | |||
| return FAILED; | |||
| } | |||
| if (GetSchemas().empty()) { | |||
| MS_LOG(ERROR) << "No schema is set"; | |||
| return FAILED; | |||
| } | |||
| for (const auto &schemaPtr : schema_) { | |||
| auto result = GetSchemaByID(schemaPtr->GetSchemaID()); | |||
| if (result.second != SUCCESS) { | |||
| MS_LOG(ERROR) << "Could not get schema by id."; | |||
| return FAILED; | |||
| } | |||
| if (result.first == nullptr) { | |||
| MS_LOG(ERROR) << "Could not get schema by id."; | |||
| return FAILED; | |||
| } | |||
| json schema = result.first->GetSchema().at("schema"); | |||
| std::shared_ptr<Schema> schema_ptr; | |||
| RETURN_IF_NOT_OK(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr)); | |||
| json schema = schema_ptr->GetSchema().at("schema"); | |||
| // checkout and add fields for each schema | |||
| std::set<std::string> field_set; | |||
| for (const auto &item : index->GetFields()) { | |||
| field_set.insert(item.second); | |||
| } | |||
| for (const auto &field : fields) { | |||
| if (field_set.find(field) != field_set.end()) { | |||
| MS_LOG(ERROR) << "Add same index field twice"; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice."); | |||
| // check field name is or is not valid | |||
| if (CheckIndexField(field, schema) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(CheckIndexField(field, schema)); | |||
| field_set.insert(field); | |||
| // add field into index | |||
| index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); | |||
| } | |||
| } | |||
| index_ = index; | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) { | |||
| Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) { | |||
| // get all schema id | |||
| for (const auto &schema : schema_) { | |||
| auto bucket_it = bucket_count.find(schema->GetSchemaID()); | |||
| if (bucket_it != bucket_count.end()) { | |||
| MS_LOG(ERROR) << "Schema duplication"; | |||
| return FAILED; | |||
| } else { | |||
| bucket_count.insert(schema->GetSchemaID()); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(bucket_it == bucket_count.end(), "Schema duplication."); | |||
| bucket_count.insert(schema->GetSchemaID()); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) { | |||
| Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) { | |||
| if (fields.empty()) { | |||
| return Status::OK(); | |||
| } | |||
| // create index Object | |||
| std::shared_ptr<Index> index = InitIndexPtr(); | |||
| if (fields.size() == kInt0) { | |||
| MS_LOG(ERROR) << "There are no index fields"; | |||
| return FAILED; | |||
| } | |||
| // get all schema id | |||
| std::set<uint64_t> bucket_count; | |||
| if (GetAllSchemaID(bucket_count) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(GetAllSchemaID(bucket_count)); | |||
| // check and add fields for each schema | |||
| std::set<std::pair<uint64_t, std::string>> field_set; | |||
| for (const auto &item : index->GetFields()) { | |||
| field_set.insert(item); | |||
| } | |||
| for (const auto &field : fields) { | |||
| if (field_set.find(field) != field_set.end()) { | |||
| MS_LOG(ERROR) << "Add same index field twice"; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice."); | |||
| uint64_t schema_id = field.first; | |||
| std::string field_name = field.second; | |||
| // check schemaId is or is not valid | |||
| if (bucket_count.find(schema_id) == bucket_count.end()) { | |||
| MS_LOG(ERROR) << "Illegal schema id: " << schema_id; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(), "Invalid schema id: " + schema_id); | |||
| // check field name is or is not valid | |||
| auto result = GetSchemaByID(schema_id); | |||
| if (result.second != SUCCESS) { | |||
| MS_LOG(ERROR) << "Could not get schema by id."; | |||
| return FAILED; | |||
| } | |||
| json schema = result.first->GetSchema().at("schema"); | |||
| if (schema.find(field_name) == schema.end()) { | |||
| MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name; | |||
| return FAILED; | |||
| } | |||
| if (CheckIndexField(field_name, schema) == FAILED) { | |||
| return FAILED; | |||
| } | |||
| std::shared_ptr<Schema> schema_ptr; | |||
| RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr)); | |||
| json schema = schema_ptr->GetSchema().at("schema"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(), | |||
| "Schema " + std::to_string(schema_id) + " do not contain the field: " + field_name); | |||
| RETURN_IF_NOT_OK(CheckIndexField(field_name, schema)); | |||
| field_set.insert(field); | |||
| // add field into index | |||
| index.get()->AddIndexField(schema_id, field_name); | |||
| index->AddIndexField(schema_id, field_name); | |||
| } | |||
| index_ = index; | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { | |||
| @@ -686,103 +568,71 @@ std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return | |||
| std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; } | |||
| std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { | |||
| Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(schema_ptr); | |||
| int64_t schemaSize = schema_.size(); | |||
| if (schema_id < 0 || schema_id >= schemaSize) { | |||
| MS_LOG(ERROR) << "Illegal schema id"; | |||
| return std::make_pair(nullptr, FAILED); | |||
| } | |||
| return std::make_pair(schema_.at(schema_id), SUCCESS); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schemaSize, "schema id is invalid."); | |||
| *schema_ptr = schema_.at(schema_id); | |||
| return Status::OK(); | |||
| } | |||
| std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) { | |||
| Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) { | |||
| RETURN_UNEXPECTED_IF_NULL(statistics_ptr); | |||
| int64_t statistics_size = statistics_.size(); | |||
| if (statistic_id < 0 || statistic_id >= statistics_size) { | |||
| return std::make_pair(nullptr, FAILED); | |||
| } | |||
| return std::make_pair(statistics_.at(statistic_id), SUCCESS); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size, "statistic id is invalid."); | |||
| *statistics_ptr = statistics_.at(statistic_id); | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { | |||
| Status ShardHeader::PagesToFile(const std::string dump_file_name) { | |||
| auto realpath = Common::GetRealPath(dump_file_name); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name); | |||
| // write header content to file, dump whatever is in the file before | |||
| std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out); | |||
| if (page_out_handle.fail()) { | |||
| MS_LOG(ERROR) << "Failed in opening page file"; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file."); | |||
| auto pages = SerializePage(); | |||
| for (const auto &shard_pages : pages) { | |||
| page_out_handle << shard_pages << "\n"; | |||
| } | |||
| page_out_handle.close(); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { | |||
| Status ShardHeader::FileToPages(const std::string dump_file_name) { | |||
| for (auto &v : pages_) { // clean pages | |||
| v.clear(); | |||
| } | |||
| auto realpath = Common::GetRealPath(dump_file_name); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name); | |||
| // attempt to open the file contains the page in json | |||
| std::ifstream page_in_handle(realpath.value()); | |||
| if (!page_in_handle.good()) { | |||
| MS_LOG(INFO) << "No page file exists."; | |||
| return SUCCESS; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(), "No page file exists."); | |||
| std::string line; | |||
| while (std::getline(page_in_handle, line)) { | |||
| if (SUCCESS != ParsePage(json::parse(line), -1, true)) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true)); | |||
| } | |||
| page_in_handle.close(); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||
| uint64_t &schema_id) { | |||
| if (header_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "ShardHeader pointer is NULL."; | |||
| return FAILED; | |||
| } | |||
| Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||
| uint64_t &schema_id) { | |||
| RETURN_UNEXPECTED_IF_NULL(header_ptr); | |||
| auto schema_ptr = Schema::Build("mindrecord", schema); | |||
| if (schema_ptr == nullptr) { | |||
| MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema."; | |||
| return FAILED; | |||
| } | |||
| RETURN_UNEXPECTED_IF_NULL(schema_ptr); | |||
| schema_id = (*header_ptr)->AddSchema(schema_ptr); | |||
| // create index | |||
| std::vector<std::pair<uint64_t, std::string>> id_index_fields; | |||
| if (!index_fields.empty()) { | |||
| (void)std::transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields), | |||
| [schema_id](const std::string &el) { return std::make_pair(schema_id, el); }); | |||
| if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) { | |||
| MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index."; | |||
| return FAILED; | |||
| } | |||
| (void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields), | |||
| [schema_id](const std::string &el) { return std::make_pair(schema_id, el); }); | |||
| RETURN_IF_NOT_OK((*header_ptr)->AddIndexFields(id_index_fields)); | |||
| } | |||
| auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; | |||
| blob_fields = build_schema_ptr->GetBlobFields(); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -37,13 +37,11 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem | |||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | |||
| } | |||
| MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) { | |||
| Status ShardPkSample::SufExecute(ShardTaskList &tasks) { | |||
| if (shuffle_ == true) { | |||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -80,7 +80,7 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| return 0; | |||
| } | |||
| MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||
| Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||
| if (tasks.permutation_.empty()) { | |||
| ShardTaskList new_tasks; | |||
| int total_no = static_cast<int>(tasks.sample_ids_.size()); | |||
| @@ -110,9 +110,7 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } else { | |||
| ShardTaskList new_tasks; | |||
| if (taking > static_cast<int>(tasks.sample_ids_.size())) { | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int>(tasks.sample_ids_.size()), "taking is out of range."); | |||
| int total_no = static_cast<int>(tasks.permutation_.size()); | |||
| int cnt = 0; | |||
| for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | |||
| @@ -122,10 +120,10 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||
| } | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardSample::Execute(ShardTaskList &tasks) { | |||
| Status ShardSample::Execute(ShardTaskList &tasks) { | |||
| if (offset_ != -1) { | |||
| int64_t old_v = 0; | |||
| int num_rows_ = static_cast<int>(tasks.sample_ids_.size()); | |||
| @@ -146,10 +144,8 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) { | |||
| no_of_samples_ = std::min(no_of_samples_, total_no); | |||
| taking = no_of_samples_ - no_of_samples_ % no_of_categories; | |||
| } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) { | |||
| if (indices_.size() > static_cast<size_t>(total_no)) { | |||
| MS_LOG(ERROR) << "parameter indices's size is greater than dataset size."; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast<size_t>(total_no), | |||
| "Parameter indices's size is greater than dataset size."); | |||
| } else { // constructor TopPercent | |||
| if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { | |||
| if (numerator_ == 1 && denominator_ > 1) { // sharding | |||
| @@ -159,20 +155,17 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) { | |||
| taking -= (taking % no_of_categories); | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "parameter numerator or denominator is illegal"; | |||
| return FAILED; | |||
| RETURN_STATUS_UNEXPECTED("Parameter numerator or denominator is invalid."); | |||
| } | |||
| } | |||
| return UpdateTasks(tasks, taking); | |||
| } | |||
| MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) { | |||
| Status ShardSample::SufExecute(ShardTaskList &tasks) { | |||
| if (sampler_type_ == kSubsetRandomSampler) { | |||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||
| return FAILED; | |||
| } | |||
| RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -38,12 +38,6 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, const json &schema) { | |||
| return std::make_shared<Schema>(object_schema); | |||
| } | |||
| std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema) { | |||
| // validate check | |||
| json schema_json = nlohmann::detail::ToJsonImpl(schema); | |||
| return Build(std::move(desc), schema_json); | |||
| } | |||
| std::string Schema::GetDesc() const { return desc_; } | |||
| json Schema::GetSchema() const { | |||
| @@ -54,12 +48,6 @@ json Schema::GetSchema() const { | |||
| return str_schema; | |||
| } | |||
| pybind11::object Schema::GetSchemaForPython() const { | |||
| json schema_json = GetSchema(); | |||
| pybind11::object schema_py = nlohmann::detail::FromJsonImpl(schema_json); | |||
| return schema_py; | |||
| } | |||
| void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } | |||
| int64_t Schema::GetSchemaID() const { return schema_id_; } | |||
| @@ -38,7 +38,7 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c | |||
| return std::min(static_cast<int64_t>(no_of_samples_), dataset_size); | |||
| } | |||
| MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||
| Status ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||
| int64_t taking; | |||
| int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size()); | |||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | |||
| @@ -58,16 +58,15 @@ MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } else { // shuffled | |||
| ShardTaskList new_tasks; | |||
| if (taking > static_cast<int64_t>(tasks.permutation_.size())) { | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int64_t>(tasks.permutation_.size()), | |||
| "Taking is out of task range."); | |||
| total_no = static_cast<int64_t>(tasks.permutation_.size()); | |||
| for (size_t i = offset_; i < taking + offset_; ++i) { | |||
| new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); | |||
| } | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| @@ -42,7 +42,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); | |||
| } | |||
| MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||
| Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||
| uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories; | |||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||
| @@ -62,17 +62,14 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||
| } | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { | |||
| Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { | |||
| if (no_of_samples_ == 0) { | |||
| no_of_samples_ = static_cast<int>(tasks.Size()); | |||
| } | |||
| if (no_of_samples_ <= 0) { | |||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); | |||
| auto shard_sample_cout = GetShardSampleCount(); | |||
| // shuffle the files index | |||
| @@ -118,16 +115,14 @@ MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { | |||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | |||
| } | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { | |||
| Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { | |||
| if (no_of_samples_ == 0) { | |||
| no_of_samples_ = static_cast<int>(tasks.Size()); | |||
| } | |||
| if (no_of_samples_ <= 0) { | |||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); | |||
| // reconstruct the permutation in file | |||
| // -- before -- | |||
| // file1: [0, 1, 2] | |||
| @@ -154,13 +149,12 @@ MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { | |||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | |||
| } | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| return Status::OK(); | |||
| } | |||
| MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||
| Status ShardShuffle::Execute(ShardTaskList &tasks) { | |||
| if (reshuffle_each_epoch_) shuffle_seed_++; | |||
| if (tasks.categories < 1) { | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(tasks.categories >= 1, "Task category is invalid."); | |||
| if (shuffle_type_ == kShuffleSample) { // shuffle each sample | |||
| if (tasks.permutation_.empty() == true) { | |||
| tasks.MakePerm(); | |||
| @@ -169,10 +163,7 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||
| if (replacement_ == true) { | |||
| ShardTaskList new_tasks; | |||
| if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size()); | |||
| if (no_of_samples_ <= 0) { | |||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||
| return FAILED; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); | |||
| for (uint32_t i = 0; i < no_of_samples_; ++i) { | |||
| new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); | |||
| } | |||
| @@ -190,20 +181,14 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | |||
| } | |||
| } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) { | |||
| auto ret = ShuffleInfile(tasks); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| RETURN_IF_NOT_OK(ShuffleInfile(tasks)); | |||
| } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) { | |||
| auto ret = ShuffleFiles(tasks); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| } | |||
| RETURN_IF_NOT_OK(ShuffleFiles(tasks)); | |||
| } | |||
| } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | |||
| return this->CategoryShuffle(tasks); | |||
| } | |||
| return SUCCESS; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -35,19 +35,6 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, const json &stat | |||
| return std::make_shared<Statistics>(object_statistics); | |||
| } | |||
| std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle statistics) { | |||
| // validate check | |||
| json statistics_json = nlohmann::detail::ToJsonImpl(statistics); | |||
| if (!Validate(statistics_json)) { | |||
| return nullptr; | |||
| } | |||
| Statistics object_statistics; | |||
| object_statistics.desc_ = std::move(desc); | |||
| object_statistics.statistics_ = statistics_json; | |||
| object_statistics.statistics_id_ = -1; | |||
| return std::make_shared<Statistics>(object_statistics); | |||
| } | |||
| std::string Statistics::GetDesc() const { return desc_; } | |||
| json Statistics::GetStatistics() const { | |||
| @@ -57,11 +44,6 @@ json Statistics::GetStatistics() const { | |||
| return str_statistics; | |||
| } | |||
| pybind11::object Statistics::GetStatisticsForPython() const { | |||
| json str_statistics = Statistics::GetStatistics(); | |||
| return nlohmann::detail::FromJsonImpl(str_statistics); | |||
| } | |||
| void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } | |||
| int64_t Statistics::GetStatisticsID() const { return statistics_id_; } | |||
| @@ -85,15 +85,9 @@ uint32_t ShardTaskList::SizeOfRows() const { | |||
| return nRows; | |||
| } | |||
| ShardTask &ShardTaskList::GetTaskByID(size_t id) { | |||
| MS_ASSERT(id < task_list_.size()); | |||
| return task_list_[id]; | |||
| } | |||
| ShardTask &ShardTaskList::GetTaskByID(size_t id) { return task_list_[id]; } | |||
| int ShardTaskList::GetTaskSampleByID(size_t id) { | |||
| MS_ASSERT(id < sample_ids_.size()); | |||
| return sample_ids_[id]; | |||
| } | |||
| int ShardTaskList::GetTaskSampleByID(size_t id) { return sample_ids_[id]; } | |||
| int ShardTaskList::GetRandomTaskID() { | |||
| std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | |||
| @@ -70,7 +70,7 @@ class ShardReader: | |||
| Raises: | |||
| MRMLaunchError: If failed to launch worker threads. | |||
| """ | |||
| ret = self._reader.launch(False) | |||
| ret = self._reader.launch() | |||
| if ret != ms.MSRStatus.SUCCESS: | |||
| logger.error("Failed to launch worker threads.") | |||
| raise MRMLaunchError | |||
| @@ -19,7 +19,6 @@ import mindspore._c_mindrecord as ms | |||
| from mindspore import log as logger | |||
| from .shardutils import populate_data, SUCCESS | |||
| from .shardheader import ShardHeader | |||
| from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError | |||
| __all__ = ['ShardSegment'] | |||
| @@ -73,15 +72,8 @@ class ShardSegment: | |||
| Returns: | |||
| list[str], by which data could be grouped. | |||
| Raises: | |||
| MRMFetchCandidateFieldsError: If failed to get candidate category fields. | |||
| """ | |||
| ret, fields = self._segment.get_category_fields() | |||
| if ret != SUCCESS: | |||
| logger.error("Failed to get candidate category fields.") | |||
| raise MRMFetchCandidateFieldsError | |||
| return fields | |||
| return self._segment.get_category_fields() | |||
| def set_category_field(self, category_field): | |||
| """Select one category field to use.""" | |||
| @@ -94,14 +86,8 @@ class ShardSegment: | |||
| Returns: | |||
| str, description fo group information. | |||
| Raises: | |||
| MRMReadCategoryInfoError: If failed to read category information. | |||
| """ | |||
| ret, category_info = self._segment.read_category_info() | |||
| if ret != SUCCESS: | |||
| logger.error("Failed to read category information.") | |||
| raise MRMReadCategoryInfoError | |||
| return category_info | |||
| return self._segment.read_category_info() | |||
| def read_at_page_by_id(self, category_id, page, num_row): | |||
| """ | |||
| @@ -116,13 +102,9 @@ class ShardSegment: | |||
| list[dict] | |||
| Raises: | |||
| MRMFetchDataError: If failed to read by category id. | |||
| MRMUnsupportedSchemaError: If schema is invalid. | |||
| """ | |||
| ret, data = self._segment.read_at_page_by_id(category_id, page, num_row) | |||
| if ret != SUCCESS: | |||
| logger.error("Failed to read by category id.") | |||
| raise MRMFetchDataError | |||
| data = self._segment.read_at_page_by_id(category_id, page, num_row) | |||
| return [populate_data(raw, blob, self._columns, self._header.blob_fields, | |||
| self._header.schema) for blob, raw in data] | |||
| @@ -139,12 +121,8 @@ class ShardSegment: | |||
| list[dict] | |||
| Raises: | |||
| MRMFetchDataError: If failed to read by category name. | |||
| MRMUnsupportedSchemaError: If schema is invalid. | |||
| """ | |||
| ret, data = self._segment.read_at_page_by_name(category_name, page, num_row) | |||
| if ret != SUCCESS: | |||
| logger.error("Failed to read by category name.") | |||
| raise MRMFetchDataError | |||
| data = self._segment.read_at_page_by_name(category_name, page, num_row) | |||
| return [populate_data(raw, blob, self._columns, self._header.blob_fields, | |||
| self._header.schema) for blob, raw in data] | |||
| @@ -384,8 +384,8 @@ void ShardWriterImageNetOpenForAppend(string filename) { | |||
| { | |||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | |||
| mindrecord::ShardWriter fw; | |||
| auto ret = fw.OpenForAppend(filename); | |||
| if (ret == FAILED) { | |||
| auto status = fw.OpenForAppend(filename); | |||
| if (status.IsError()) { | |||
| return; | |||
| } | |||
| @@ -121,14 +121,19 @@ TEST_F(TestShard, TestShardHeaderPart) { | |||
| re_statistics.push_back(*statistic); | |||
| } | |||
| ASSERT_EQ(re_statistics, validate_statistics); | |||
| ASSERT_EQ(header_data.GetStatisticByID(-1).second, FAILED); | |||
| ASSERT_EQ(header_data.GetStatisticByID(10).second, FAILED); | |||
| std::shared_ptr<Statistics> statistics_ptr; | |||
| auto status = header_data.GetStatisticByID(-1, &statistics_ptr); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| status = header_data.GetStatisticByID(10, &statistics_ptr); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| // test add index fields | |||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||
| std::pair<uint64_t, std::string> pair1(0, "name"); | |||
| fields.push_back(pair1); | |||
| ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS); | |||
| status = header_data.AddIndexFields(fields); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields(); | |||
| ASSERT_EQ(resFields, fields); | |||
| } | |||
| @@ -79,36 +79,37 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||
| std::pair<uint64_t, std::string> index_field2(schema_id1, "box"); | |||
| fields.push_back(index_field1); | |||
| fields.push_back(index_field2); | |||
| MSRStatus res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, SUCCESS); | |||
| Status status = header_data.AddIndexFields(fields); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field3(schema_id1, "name"); | |||
| fields.push_back(index_field3); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| status = header_data.AddIndexFields(fields); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field4(schema_id1, "names"); | |||
| fields.push_back(index_field4); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| status = header_data.AddIndexFields(fields); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name"); | |||
| fields.push_back(index_field5); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| status = header_data.AddIndexFields(fields); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field6(schema_id1, "label"); | |||
| fields.push_back(index_field6); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| status = header_data.AddIndexFields(fields); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| std::string desc_new = "this is a test1"; | |||
| @@ -129,26 +130,26 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||
| single_fields.push_back("name"); | |||
| single_fields.push_back("name"); | |||
| single_fields.push_back("box"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| status = header_data_new.AddIndexFields(single_fields); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||
| single_fields.push_back("name"); | |||
| single_fields.push_back("box"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| status = header_data_new.AddIndexFields(single_fields); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||
| single_fields.clear(); | |||
| single_fields.push_back("names"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| status = header_data_new.AddIndexFields(single_fields); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||
| single_fields.clear(); | |||
| single_fields.push_back("box"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, SUCCESS); | |||
| status = header_data_new.AddIndexFields(single_fields); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 2); | |||
| } | |||
| } // namespace mindrecord | |||
| @@ -167,8 +167,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | |||
| std::string file_name = "./imagenet.shard01"; | |||
| auto column_list = std::vector<std::string>{"label"}; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| auto status = dataset.Open({file_name}, true, 4, column_list); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| dataset.Launch(); | |||
| while (true) { | |||
| @@ -188,16 +188,16 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { | |||
| std::string file_name = "./imagenet.shard01"; | |||
| auto column_list = std::vector<std::string>{"file_namex"}; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||
| ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST); | |||
| auto status= dataset.Open({file_name}, true, 4, column_list); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| } | |||
| TEST_F(TestShardReader, TestShardVersion) { | |||
| MS_LOG(INFO) << FormatInfo("Test shard version"); | |||
| std::string file_name = "./imagenet.shard01"; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({file_name}, true, 4); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| auto status = dataset.Open({file_name}, true, 4); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| dataset.Launch(); | |||
| while (true) { | |||
| @@ -219,8 +219,8 @@ TEST_F(TestShardReader, TestShardReaderDir) { | |||
| auto column_list = std::vector<std::string>{"file_name"}; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||
| ASSERT_EQ(ret, FAILED); | |||
| auto status = dataset.Open({file_name}, true, 4, column_list); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| } | |||
| TEST_F(TestShardReader, TestShardReaderConsumer) { | |||
| @@ -61,35 +61,44 @@ TEST_F(TestShardSegment, TestShardSegment) { | |||
| ShardSegment dataset; | |||
| dataset.Open({file_name}, true, 4); | |||
| auto x = dataset.GetCategoryFields(); | |||
| for (const auto &fields : x.second) { | |||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||
| for (const auto &fields : *fields_ptr) { | |||
| MS_LOG(INFO) << "Get category field: " << fields; | |||
| } | |||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||
| ASSERT_TRUE(dataset.SetCategoryField("laabel_0") == FAILED); | |||
| status = dataset.SetCategoryField("label"); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| status = dataset.SetCategoryField("laabel_0"); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||
| auto ret = dataset.ReadAtPageByName("822", 0, 10); | |||
| auto images = ret.second; | |||
| MS_LOG(INFO) << "category field: 822, images count: " << images.size() << ", image[0] size: " << images[0].size(); | |||
| std::shared_ptr<std::string> category_ptr; | |||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||
| auto ret1 = dataset.ReadAtPageByName("823", 0, 10); | |||
| auto images2 = ret1.second; | |||
| MS_LOG(INFO) << "category field: 823, images count: " << images2.size(); | |||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| status = dataset.ReadAtPageByName("822", 0, 10, &pages_ptr); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr->size() << ", image[0] size: " << ((*pages_ptr)[0]).size(); | |||
| auto ret2 = dataset.ReadAtPageById(1, 0, 10); | |||
| auto images3 = ret2.second; | |||
| MS_LOG(INFO) << "category id: 1, images count: " << images3.size() << ", image[0] size: " << images3[0].size(); | |||
| auto pages_ptr_1 = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| status = dataset.ReadAtPageByName("823", 0, 10, &pages_ptr_1); | |||
| MS_LOG(INFO) << "category field: 823, images count: " << pages_ptr_1->size(); | |||
| auto ret3 = dataset.ReadAllAtPageByName("822", 0, 10); | |||
| auto images4 = ret3.second; | |||
| MS_LOG(INFO) << "category field: 822, images count: " << images4.size(); | |||
| auto pages_ptr_2 = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| status = dataset.ReadAtPageById(1, 0, 10, &pages_ptr_2); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_2->size() << ", image[0] size: " << ((*pages_ptr_2)[0]).size(); | |||
| auto ret4 = dataset.ReadAllAtPageById(1, 0, 10); | |||
| auto images5 = ret4.second; | |||
| MS_LOG(INFO) << "category id: 1, images count: " << images5.size(); | |||
| auto pages_ptr_3 = std::make_shared<PAGES>(); | |||
| status = dataset.ReadAllAtPageByName("822", 0, 10, &pages_ptr_3); | |||
| MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr_3->size(); | |||
| auto pages_ptr_4 = std::make_shared<PAGES>(); | |||
| status = dataset.ReadAllAtPageById(1, 0, 10, &pages_ptr_4); | |||
| MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_4->size(); | |||
| } | |||
| TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { | |||
| @@ -99,21 +108,28 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { | |||
| ShardSegment dataset; | |||
| dataset.Open({file_name}, true, 4); | |||
| auto x = dataset.GetCategoryFields(); | |||
| for (const auto &fields : x.second) { | |||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||
| for (const auto &fields : *fields_ptr) { | |||
| MS_LOG(INFO) << "Get category field: " << fields; | |||
| } | |||
| string category_name = "82Cus"; | |||
| string category_field = "laabel_0"; | |||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||
| ASSERT_TRUE(dataset.SetCategoryField(category_field) == FAILED); | |||
| status = dataset.SetCategoryField("label"); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| status = dataset.SetCategoryField(category_field); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||
| std::shared_ptr<std::string> category_ptr; | |||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||
| auto ret = dataset.ReadAtPageByName(category_name, 0, 10); | |||
| EXPECT_TRUE(ret.first == FAILED); | |||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| status = dataset.ReadAtPageByName(category_name, 0, 10, &pages_ptr); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| } | |||
| TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { | |||
| @@ -123,19 +139,25 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { | |||
| ShardSegment dataset; | |||
| dataset.Open({file_name}, true, 4); | |||
| auto x = dataset.GetCategoryFields(); | |||
| for (const auto &fields : x.second) { | |||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||
| for (const auto &fields : *fields_ptr) { | |||
| MS_LOG(INFO) << "Get category field: " << fields; | |||
| } | |||
| int64_t categoryId = 2251799813685247; | |||
| MS_LOG(INFO) << "Input category id: " << categoryId; | |||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||
| status = dataset.SetCategoryField("label"); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| std::shared_ptr<std::string> category_ptr; | |||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||
| auto ret2 = dataset.ReadAtPageById(categoryId, 0, 10); | |||
| EXPECT_TRUE(ret2.first == FAILED); | |||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| status = dataset.ReadAtPageById(categoryId, 0, 10, &pages_ptr); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| } | |||
| TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { | |||
| @@ -145,19 +167,27 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { | |||
| ShardSegment dataset; | |||
| dataset.Open({file_name}, true, 4); | |||
| auto x = dataset.GetCategoryFields(); | |||
| for (const auto &fields : x.second) { | |||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||
| for (const auto &fields : *fields_ptr) { | |||
| MS_LOG(INFO) << "Get category field: " << fields; | |||
| } | |||
| int64_t page_no = 2251799813685247; | |||
| MS_LOG(INFO) << "Input page no: " << page_no; | |||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||
| status = dataset.SetCategoryField("label"); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| auto ret2 = dataset.ReadAtPageById(1, page_no, 10); | |||
| EXPECT_TRUE(ret2.first == FAILED); | |||
| std::shared_ptr<std::string> category_ptr; | |||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| status = dataset.ReadAtPageById(1, page_no, 10, &pages_ptr); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| } | |||
| TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { | |||
| @@ -167,19 +197,26 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { | |||
| ShardSegment dataset; | |||
| dataset.Open({file_name}, true, 4); | |||
| auto x = dataset.GetCategoryFields(); | |||
| for (const auto &fields : x.second) { | |||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||
| for (const auto &fields : *fields_ptr) { | |||
| MS_LOG(INFO) << "Get category field: " << fields; | |||
| } | |||
| int64_t pageRows = 0; | |||
| MS_LOG(INFO) << "Input page rows: " << pageRows; | |||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||
| status = dataset.SetCategoryField("label"); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| std::shared_ptr<std::string> category_ptr; | |||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||
| auto ret2 = dataset.ReadAtPageById(1, 0, pageRows); | |||
| EXPECT_TRUE(ret2.first == FAILED); | |||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||
| status = dataset.ReadAtPageById(1, 0, pageRows, &pages_ptr); | |||
| EXPECT_FALSE(status.IsOk()); | |||
| } | |||
| } // namespace mindrecord | |||
| @@ -60,8 +60,8 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { | |||
| std::string filename = "./OneSample.shard01"; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({filename}, true, 4); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| auto status = dataset.Open({filename}, true, 4); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| dataset.Launch(); | |||
| while (true) { | |||
| @@ -675,8 +675,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) { | |||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| // write rawdata | |||
| MSRStatus res = fw.WriteRawData(rawdatas, bin_data); | |||
| ASSERT_EQ(res, SUCCESS); | |||
| auto status = fw.WriteRawData(rawdatas, bin_data); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| for (const auto &filename : file_names) { | |||
| auto filename_db = filename + ".db"; | |||
| remove(common::SafeCStr(filename_db)); | |||
| @@ -716,7 +716,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { | |||
| fields.push_back(index_field2); | |||
| // add index to shardHeader | |||
| ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS); | |||
| auto status = header_data.AddIndexFields(fields); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||
| // load meta data | |||
| @@ -736,28 +737,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { | |||
| } | |||
| mindrecord::ShardWriter fw_init; | |||
| ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); | |||
| status = fw_init.Open(file_names); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // set shardHeader | |||
| ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS); | |||
| status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // write raw data | |||
| ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); | |||
| ASSERT_TRUE(fw_init.Commit() == SUCCESS); | |||
| status = fw_init.WriteRawData(rawdatas, bin_data); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| status = fw_init.Commit(); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // create the index file | |||
| std::string filename = "./imagenet.shard01"; | |||
| mindrecord::ShardIndexGenerator sg{filename}; | |||
| sg.Build(); | |||
| ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); | |||
| status = sg.WriteToDatabase(); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Done create index"; | |||
| // read the mindrecord file | |||
| filename = "./imagenet.shard01"; | |||
| auto column_list = std::vector<std::string>{"label", "file_name", "data"}; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| status = dataset.Open({filename}, true, 4, column_list); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| dataset.Launch(); | |||
| int count = 0; | |||
| @@ -822,28 +829,34 @@ TEST_F(TestShardWriter, TestShardNoBlob) { | |||
| } | |||
| mindrecord::ShardWriter fw_init; | |||
| ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); | |||
| auto status = fw_init.Open(file_names); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // set shardHeader | |||
| ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS); | |||
| status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // write raw data | |||
| ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); | |||
| ASSERT_TRUE(fw_init.Commit() == SUCCESS); | |||
| status = fw_init.WriteRawData(rawdatas, bin_data); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| status = fw_init.Commit(); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // create the index file | |||
| std::string filename = "./imagenet.shard01"; | |||
| mindrecord::ShardIndexGenerator sg{filename}; | |||
| sg.Build(); | |||
| ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); | |||
| status = sg.WriteToDatabase(); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Done create index"; | |||
| // read the mindrecord file | |||
| filename = "./imagenet.shard01"; | |||
| auto column_list = std::vector<std::string>{"label", "file_name"}; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| status = dataset.Open({filename}, true, 4, column_list); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| dataset.Launch(); | |||
| int count = 0; | |||
| @@ -896,7 +909,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { | |||
| fields.push_back(index_field1); | |||
| // add index to shardHeader | |||
| ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS); | |||
| auto status = header_data.AddIndexFields(fields); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Init Index Fields Already."; | |||
| // load meta data | |||
| @@ -916,28 +930,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { | |||
| } | |||
| mindrecord::ShardWriter fw_init; | |||
| ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); | |||
| status = fw_init.Open(file_names); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // set shardHeader | |||
| ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS); | |||
| status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // write raw data | |||
| ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); | |||
| ASSERT_TRUE(fw_init.Commit() == SUCCESS); | |||
| status = fw_init.WriteRawData(rawdatas, bin_data); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| status = fw_init.Commit(); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| // create the index file | |||
| std::string filename = "./imagenet.shard01"; | |||
| mindrecord::ShardIndexGenerator sg{filename}; | |||
| sg.Build(); | |||
| ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); | |||
| status = sg.WriteToDatabase(); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| MS_LOG(INFO) << "Done create index"; | |||
| // read the mindrecord file | |||
| filename = "./imagenet.shard01"; | |||
| auto column_list = std::vector<std::string>{"label", "data"}; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| status = dataset.Open({filename}, true, 4, column_list); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| dataset.Launch(); | |||
| int count = 0; | |||
| @@ -1043,8 +1063,8 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { | |||
| filename = "./TenSampleFortyShard.shard01"; | |||
| ShardReader dataset; | |||
| MSRStatus ret = dataset.Open({filename}, true, 4); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| auto status = dataset.Open({filename}, true, 4); | |||
| EXPECT_TRUE(status.IsOk()); | |||
| dataset.Launch(); | |||
| int count = 0; | |||
| @@ -95,7 +95,7 @@ def test_invalid_mindrecord(): | |||
| f.write('just for test') | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid file content. path:"): | |||
| data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| @@ -114,7 +114,7 @@ def test_minddataset_lack_db(): | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file:"): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| @@ -133,7 +133,7 @@ def test_cv_minddataset_pk_sample_error_class_column(): | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.PKSampler(5, None, True, 'no_exist_column') | |||
| with pytest.raises(Exception, match="MindRecordOp launch failed"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."): | |||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler) | |||
| num_iter = 0 | |||
| for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| @@ -162,7 +162,7 @@ def test_cv_minddataset_reader_different_schema(): | |||
| create_diff_schema_cv_mindrecord(1) | |||
| columns_list = ["data", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||
| with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"): | |||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | |||
| num_readers) | |||
| num_iter = 0 | |||
| @@ -179,7 +179,7 @@ def test_cv_minddataset_reader_different_page_size(): | |||
| create_diff_page_size_cv_mindrecord(1) | |||
| columns_list = ["data", "label"] | |||
| num_readers = 4 | |||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||
| with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"): | |||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | |||
| num_readers) | |||
| num_iter = 0 | |||
| @@ -19,7 +19,6 @@ import pytest | |||
| from mindspore import log as logger | |||
| from mindspore.mindrecord import Cifar100ToMR | |||
| from mindspore.mindrecord import FileReader | |||
| from mindspore.mindrecord import MRMOpenError | |||
| from mindspore.mindrecord import SUCCESS | |||
| CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | |||
| @@ -119,8 +118,8 @@ def test_cifar100_to_mindrecord_directory(fixture_file): | |||
| test transform cifar10 dataset to mindrecord | |||
| when destination path is directory. | |||
| """ | |||
| with pytest.raises(MRMOpenError, | |||
| match="MindRecord File could not open successfully"): | |||
| with pytest.raises(RuntimeError, | |||
| match="MindRecord file already existed, please delete file:"): | |||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, | |||
| CIFAR100_DIR) | |||
| cifar100_transformer.transform() | |||
| @@ -130,8 +129,8 @@ def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file): | |||
| test transform cifar10 dataset to mindrecord | |||
| when destination path equals source path. | |||
| """ | |||
| with pytest.raises(MRMOpenError, | |||
| match="MindRecord File could not open successfully"): | |||
| with pytest.raises(RuntimeError, | |||
| match="indRecord file already existed, please delete file:"): | |||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, | |||
| CIFAR100_DIR + "/train") | |||
| cifar100_transformer.transform() | |||
| @@ -19,7 +19,7 @@ import pytest | |||
| from mindspore import log as logger | |||
| from mindspore.mindrecord import Cifar10ToMR | |||
| from mindspore.mindrecord import FileReader | |||
| from mindspore.mindrecord import MRMOpenError, SUCCESS | |||
| from mindspore.mindrecord import SUCCESS | |||
| CIFAR10_DIR = "../data/mindrecord/testCifar10Data" | |||
| MINDRECORD_FILE = "./cifar10.mindrecord" | |||
| @@ -146,8 +146,8 @@ def test_cifar10_to_mindrecord_directory(fixture_file): | |||
| test transform cifar10 dataset to mindrecord | |||
| when destination path is directory. | |||
| """ | |||
| with pytest.raises(MRMOpenError, | |||
| match="MindRecord File could not open successfully"): | |||
| with pytest.raises(RuntimeError, | |||
| match="MindRecord file already existed, please delete file:"): | |||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR) | |||
| cifar10_transformer.transform() | |||
| @@ -157,8 +157,8 @@ def test_cifar10_to_mindrecord_filename_equals_cifar10(): | |||
| test transform cifar10 dataset to mindrecord | |||
| when destination path equals source path. | |||
| """ | |||
| with pytest.raises(MRMOpenError, | |||
| match="MindRecord File could not open successfully"): | |||
| with pytest.raises(RuntimeError, | |||
| match="MindRecord file already existed, please delete file:"): | |||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, | |||
| CIFAR10_DIR + "/data_batch_0") | |||
| cifar10_transformer.transform() | |||
| @@ -21,8 +21,7 @@ from utils import get_data | |||
| from mindspore import log as logger | |||
| from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS | |||
| from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \ | |||
| MRMFetchDataError | |||
| from mindspore.mindrecord import ParamValueError, MRMGetMetaError | |||
| CV_FILE_NAME = "./imagenet.mindrecord" | |||
| NLP_FILE_NAME = "./aclImdb.mindrecord" | |||
| @@ -106,21 +105,19 @@ def create_cv_mindrecord(files_num): | |||
| def test_lack_partition_and_db(): | |||
| """test file reader when mindrecord file does not exist.""" | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| reader = FileReader('dummy.mindrecord') | |||
| reader.close() | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||
| def test_lack_db(fixture_cv_file): | |||
| """test file reader when db file does not exist.""" | |||
| create_cv_mindrecord(1) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| reader = FileReader(CV_FILE_NAME) | |||
| reader.close() | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid database file:' in str(err.value) | |||
| def test_lack_some_partition_and_db(fixture_cv_file): | |||
| """test file reader when some partition and db do not exist.""" | |||
| @@ -129,11 +126,10 @@ def test_lack_some_partition_and_db(fixture_cv_file): | |||
| for x in range(FILES_NUM)] | |||
| os.remove("{}".format(paths[3])) | |||
| os.remove("{}.db".format(paths[3])) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| reader = FileReader(CV_FILE_NAME + "0") | |||
| reader.close() | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||
| def test_lack_some_partition_first(fixture_cv_file): | |||
| """test file reader when first partition does not exist.""" | |||
| @@ -141,11 +137,10 @@ def test_lack_some_partition_first(fixture_cv_file): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| os.remove("{}".format(paths[0])) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| reader = FileReader(CV_FILE_NAME + "0") | |||
| reader.close() | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||
| def test_lack_some_partition_middle(fixture_cv_file): | |||
| """test file reader when some partition does not exist.""" | |||
| @@ -153,11 +148,10 @@ def test_lack_some_partition_middle(fixture_cv_file): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| os.remove("{}".format(paths[1])) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| reader = FileReader(CV_FILE_NAME + "0") | |||
| reader.close() | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||
| def test_lack_some_partition_last(fixture_cv_file): | |||
| """test file reader when last partition does not exist.""" | |||
| @@ -165,11 +159,10 @@ def test_lack_some_partition_last(fixture_cv_file): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| os.remove("{}".format(paths[3])) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| reader = FileReader(CV_FILE_NAME + "0") | |||
| reader.close() | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||
| def test_mindpage_lack_some_partition(fixture_cv_file): | |||
| """test page reader when some partition does not exist.""" | |||
| @@ -177,10 +170,9 @@ def test_mindpage_lack_some_partition(fixture_cv_file): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| os.remove("{}".format(paths[0])) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| MindPage(CV_FILE_NAME + "0") | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||
| def test_lack_some_db(fixture_cv_file): | |||
| """test file reader when some db does not exist.""" | |||
| @@ -188,11 +180,10 @@ def test_lack_some_db(fixture_cv_file): | |||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | |||
| for x in range(FILES_NUM)] | |||
| os.remove("{}.db".format(paths[3])) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| reader = FileReader(CV_FILE_NAME + "0") | |||
| reader.close() | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid database file:' in str(err.value) | |||
| def test_invalid_mindrecord(): | |||
| @@ -200,10 +191,9 @@ def test_invalid_mindrecord(): | |||
| with open(CV_FILE_NAME, 'w') as f: | |||
| dummy = 's' * 100 | |||
| f.write(dummy) | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| FileReader(CV_FILE_NAME) | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Invalid file content. path:' in str(err.value) | |||
| os.remove(CV_FILE_NAME) | |||
| def test_invalid_db(fixture_cv_file): | |||
| @@ -212,27 +202,26 @@ def test_invalid_db(fixture_cv_file): | |||
| os.remove("imagenet.mindrecord.db") | |||
| with open('imagenet.mindrecord.db', 'w') as f: | |||
| f.write('just for test') | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| FileReader('imagenet.mindrecord') | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| in str(err.value) | |||
| assert 'Unexpected error. Error in execute sql:' in str(err.value) | |||
| def test_overwrite_invalid_mindrecord(fixture_cv_file): | |||
| """test file writer when overwrite invalid mindrecord file.""" | |||
| with open(CV_FILE_NAME, 'w') as f: | |||
| f.write('just for test') | |||
| with pytest.raises(MRMOpenError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| create_cv_mindrecord(1) | |||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||
| assert 'Unexpected error. MindRecord file already existed, please delete file:' \ | |||
| in str(err.value) | |||
| def test_overwrite_invalid_db(fixture_cv_file): | |||
| """test file writer when overwrite invalid db file.""" | |||
| with open('imagenet.mindrecord.db', 'w') as f: | |||
| f.write('just for test') | |||
| with pytest.raises(MRMGenerateIndexError) as err: | |||
| with pytest.raises(RuntimeError) as err: | |||
| create_cv_mindrecord(1) | |||
| assert '[MRMGenerateIndexError]: Failed to generate index.' in str(err.value) | |||
| assert 'Unexpected error. Failed to write data to db.' in str(err.value) | |||
| def test_read_after_close(fixture_cv_file): | |||
| """test file reader when close read.""" | |||
| @@ -302,7 +291,7 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file): | |||
| with pytest.raises(ParamValueError): | |||
| reader.read_at_page_by_name("822", 0, "qwer") | |||
| with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"): | |||
| reader.read_at_page_by_id(99999, 0, 1) | |||
| @@ -320,10 +309,10 @@ def test_mindpage_filename_not_exist(fixture_cv_file): | |||
| info = reader.read_category_info() | |||
| logger.info("category info: {}".format(info)) | |||
| with pytest.raises(MRMFetchDataError): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"): | |||
| reader.read_at_page_by_id(9999, 0, 1) | |||
| with pytest.raises(MRMFetchDataError): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid category name."): | |||
| reader.read_at_page_by_name("abc.jpg", 0, 1) | |||
| with pytest.raises(ParamValueError): | |||
| @@ -475,7 +464,7 @@ def test_write_with_invalid_data(): | |||
| mindrecord_file_name = "test.mindrecord" | |||
| # field: file_name => filename | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -510,7 +499,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # field: mask => masks | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -545,7 +534,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # field: data => image | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -580,7 +569,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # field: label => labels | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -615,7 +604,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # field: score => scores | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -650,7 +639,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # string type with int value | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -685,7 +674,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # field with int64 type, but the real data is string | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -720,7 +709,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # bytes field is string | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -755,7 +744,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # field is not numpy type | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||
| @@ -790,7 +779,7 @@ def test_write_with_invalid_data(): | |||
| writer.commit() | |||
| # not enough field | |||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||
| remove_one_file(mindrecord_file_name) | |||
| remove_one_file(mindrecord_file_name + ".db") | |||