| @@ -256,9 +256,7 @@ Status SaveToDisk::Save() { | |||||
| auto mr_header = std::make_shared<mindrecord::ShardHeader>(); | auto mr_header = std::make_shared<mindrecord::ShardHeader>(); | ||||
| auto mr_writer = std::make_unique<mindrecord::ShardWriter>(); | auto mr_writer = std::make_unique<mindrecord::ShardWriter>(); | ||||
| std::vector<std::string> blob_fields; | std::vector<std::string> blob_fields; | ||||
| if (mindrecord::SUCCESS != mindrecord::ShardWriter::Initialize(&mr_writer, file_names)) { | |||||
| RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardWriter, please check above `ERROR` level message."); | |||||
| } | |||||
| RETURN_IF_NOT_OK(mindrecord::ShardWriter::Initialize(&mr_writer, file_names)); | |||||
| std::unordered_map<std::string, int32_t> column_name_id_map; | std::unordered_map<std::string, int32_t> column_name_id_map; | ||||
| for (auto el : tree_adapter_->GetColumnNameMap()) { | for (auto el : tree_adapter_->GetColumnNameMap()) { | ||||
| @@ -286,22 +284,16 @@ Status SaveToDisk::Save() { | |||||
| std::vector<std::string> index_fields; | std::vector<std::string> index_fields; | ||||
| RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields)); | RETURN_IF_NOT_OK(FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields)); | ||||
| MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump(); | MS_LOG(INFO) << "Schema of saved mindrecord: " << mr_json.dump(); | ||||
| if (mindrecord::SUCCESS != | |||||
| mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) { | |||||
| RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader."); | |||||
| } | |||||
| if (mindrecord::SUCCESS != mr_writer->SetShardHeader(mr_header)) { | |||||
| RETURN_STATUS_UNEXPECTED("Error: failed to set header of ShardWriter."); | |||||
| } | |||||
| RETURN_IF_NOT_OK( | |||||
| mindrecord::ShardHeader::Initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)); | |||||
| RETURN_IF_NOT_OK(mr_writer->SetShardHeader(mr_header)); | |||||
| first_loop = false; | first_loop = false; | ||||
| } | } | ||||
| // construct data | // construct data | ||||
| if (!row.empty()) { // write data | if (!row.empty()) { // write data | ||||
| RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data)); | RETURN_IF_NOT_OK(FetchDataFromTensorRow(row, column_name_id_map, &row_raw_data, &row_bin_data)); | ||||
| std::shared_ptr<std::vector<uint8_t>> output_bin_data; | std::shared_ptr<std::vector<uint8_t>> output_bin_data; | ||||
| if (mindrecord::SUCCESS != mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)) { | |||||
| RETURN_STATUS_UNEXPECTED("Error: failed to merge blob data of ShardWriter."); | |||||
| } | |||||
| RETURN_IF_NOT_OK(mr_writer->MergeBlobData(blob_fields, row_bin_data, &output_bin_data)); | |||||
| std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data; | std::map<std::uint64_t, std::vector<nlohmann::json>> raw_data; | ||||
| raw_data.insert( | raw_data.insert( | ||||
| std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data})); | std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data})); | ||||
| @@ -309,18 +301,12 @@ Status SaveToDisk::Save() { | |||||
| if (output_bin_data != nullptr) { | if (output_bin_data != nullptr) { | ||||
| bin_data.emplace_back(*output_bin_data); | bin_data.emplace_back(*output_bin_data); | ||||
| } | } | ||||
| if (mindrecord::SUCCESS != mr_writer->WriteRawData(raw_data, bin_data)) { | |||||
| RETURN_STATUS_UNEXPECTED("Error: failed to write raw data to ShardWriter."); | |||||
| } | |||||
| RETURN_IF_NOT_OK(mr_writer->WriteRawData(raw_data, bin_data)); | |||||
| } | } | ||||
| } while (!row.empty()); | } while (!row.empty()); | ||||
| if (mindrecord::SUCCESS != mr_writer->Commit()) { | |||||
| RETURN_STATUS_UNEXPECTED("Error: failed to commit ShardWriter."); | |||||
| } | |||||
| if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::Finalize(file_names)) { | |||||
| RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator."); | |||||
| } | |||||
| RETURN_IF_NOT_OK(mr_writer->Commit()); | |||||
| RETURN_IF_NOT_OK(mindrecord::ShardIndexGenerator::Finalize(file_names)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/core/global_context.h" | #include "minddata/dataset/core/global_context.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/mind_record_sampler.h" | ||||
| #include "minddata/mindrecord/include/shard_column.h" | |||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/include/dataset/constants.h" | #include "minddata/dataset/include/dataset/constants.h" | ||||
| @@ -63,10 +64,8 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str | |||||
| // Private helper method to encapsulate some common construction/reset tasks | // Private helper method to encapsulate some common construction/reset tasks | ||||
| Status MindRecordOp::Init() { | Status MindRecordOp::Init() { | ||||
| auto rc = shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, operators_, | |||||
| num_padded_); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rc == MSRStatus::SUCCESS, "MindRecordOp init failed, " + ErrnoToMessage(rc)); | |||||
| RETURN_IF_NOT_OK(shard_reader_->Open(dataset_file_, load_dataset_, num_mind_record_workers_, columns_to_load_, | |||||
| operators_, num_padded_)); | |||||
| data_schema_ = std::make_unique<DataSchema>(); | data_schema_ = std::make_unique<DataSchema>(); | ||||
| @@ -206,7 +205,9 @@ Status MindRecordOp::GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, i | |||||
| fetched_row->setPath(file_path); | fetched_row->setPath(file_path); | ||||
| fetched_row->setId(row_id); | fetched_row->setId(row_id); | ||||
| } | } | ||||
| if (tupled_buffer.empty()) return Status::OK(); | |||||
| if (tupled_buffer.empty()) { | |||||
| return Status::OK(); | |||||
| } | |||||
| if (task_type == mindrecord::TaskType::kCommonTask) { | if (task_type == mindrecord::TaskType::kCommonTask) { | ||||
| for (const auto &tupled_row : tupled_buffer) { | for (const auto &tupled_row : tupled_buffer) { | ||||
| std::vector<uint8_t> columns_blob = std::get<0>(tupled_row); | std::vector<uint8_t> columns_blob = std::get<0>(tupled_row); | ||||
| @@ -237,20 +238,15 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint | |||||
| // Get column data | // Get column data | ||||
| auto shard_column = shard_reader_->GetShardColumn(); | auto shard_column = shard_reader_->GetShardColumn(); | ||||
| if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { | if (num_padded_ > 0 && task_type == mindrecord::TaskType::kPaddedTask) { | ||||
| auto rc = | |||||
| shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, &column_shape); | |||||
| if (rc.first != MSRStatus::SUCCESS) { | |||||
| RETURN_STATUS_UNEXPECTED("Invalid parameter, column_name: " + column_name + "does not exist in dataset."); | |||||
| } | |||||
| if (rc.second == mindrecord::ColumnInRaw) { | |||||
| auto column_in_raw = shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes); | |||||
| if (column_in_raw == MSRStatus::FAILED) { | |||||
| RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve raw data from padding sample."); | |||||
| } | |||||
| } else if (rc.second == mindrecord::ColumnInBlob) { | |||||
| if (sample_bytes_.find(column_name) == sample_bytes_.end()) { | |||||
| RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve blob data from padding sample."); | |||||
| } | |||||
| mindrecord::ColumnCategory category; | |||||
| RETURN_IF_NOT_OK(shard_column->GetColumnTypeByName(column_name, &column_data_type, &column_data_type_size, | |||||
| &column_shape, &category)); | |||||
| if (category == mindrecord::ColumnInRaw) { | |||||
| RETURN_IF_NOT_OK(shard_column->GetColumnFromJson(column_name, sample_json_, &data_ptr, &n_bytes)); | |||||
| } else if (category == mindrecord::ColumnInBlob) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(sample_bytes_.find(column_name) != sample_bytes_.end(), | |||||
| "Invalid data, failed to retrieve blob data from padding sample."); | |||||
| std::string ss(sample_bytes_[column_name]); | std::string ss(sample_bytes_[column_name]); | ||||
| n_bytes = ss.size(); | n_bytes = ss.size(); | ||||
| data_ptr = std::make_unique<unsigned char[]>(n_bytes); | data_ptr = std::make_unique<unsigned char[]>(n_bytes); | ||||
| @@ -262,12 +258,9 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint | |||||
| data = reinterpret_cast<const unsigned char *>(data_ptr.get()); | data = reinterpret_cast<const unsigned char *>(data_ptr.get()); | ||||
| } | } | ||||
| } else { | } else { | ||||
| auto has_column = | |||||
| shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, | |||||
| &column_data_type, &column_data_type_size, &column_shape); | |||||
| if (has_column == MSRStatus::FAILED) { | |||||
| RETURN_STATUS_UNEXPECTED("Invalid data, failed to retrieve data from mindrecord reader."); | |||||
| } | |||||
| RETURN_IF_NOT_OK(shard_column->GetColumnValueByName(column_name, columns_blob, columns_json, &data, &data_ptr, | |||||
| &n_bytes, &column_data_type, &column_data_type_size, | |||||
| &column_shape)); | |||||
| } | } | ||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| @@ -309,15 +302,10 @@ Status MindRecordOp::Reset() { | |||||
| } | } | ||||
| Status MindRecordOp::LaunchThreadsAndInitOp() { | Status MindRecordOp::LaunchThreadsAndInitOp() { | ||||
| if (tree_ == nullptr) { | |||||
| RETURN_STATUS_UNEXPECTED("Pipeline init failed, Execution tree not set."); | |||||
| } | |||||
| RETURN_UNEXPECTED_IF_NULL(tree_); | |||||
| RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks())); | ||||
| RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | RETURN_IF_NOT_OK(wait_for_workers_post_.Register(tree_->AllTasks())); | ||||
| if (shard_reader_->Launch(true) == MSRStatus::FAILED) { | |||||
| RETURN_STATUS_UNEXPECTED("MindRecordOp launch failed."); | |||||
| } | |||||
| RETURN_IF_NOT_OK(shard_reader_->Launch(true)); | |||||
| // Launch main workers that load TensorRows by reading all images | // Launch main workers that load TensorRows by reading all images | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id())); | tree_->LaunchWorkers(num_workers_, std::bind(&MindRecordOp::WorkerEntry, this, std::placeholders::_1), "", id())); | ||||
| @@ -330,12 +318,7 @@ Status MindRecordOp::LaunchThreadsAndInitOp() { | |||||
| Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset, | Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path, bool load_dataset, | ||||
| const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded) { | const std::shared_ptr<ShardOperator> &op, int64_t *count, int64_t num_padded) { | ||||
| std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); | std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); | ||||
| MSRStatus rc = shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded); | |||||
| if (rc == MSRStatus::FAILED) { | |||||
| RETURN_STATUS_UNEXPECTED( | |||||
| "Invalid data, MindRecordOp failed to count total rows. Check whether there are corresponding .db files " | |||||
| "and the value of dataset_file parameter is given correctly."); | |||||
| } | |||||
| RETURN_IF_NOT_OK(shard_reader->CountTotalRows(dataset_path, load_dataset, op, count, num_padded)); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -38,9 +38,8 @@ Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std:: | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | uint64_t n_bytes = 0, col_type_size = 1; | ||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | ||||
| std::vector<int64_t> column_shape; | std::vector<int64_t> column_shape; | ||||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||||
| RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape)); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | ||||
| RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), | RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})), | ||||
| std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), | std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), | ||||
| @@ -57,9 +56,8 @@ Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, con | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | uint64_t n_bytes = 0, col_type_size = 1; | ||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | ||||
| std::vector<int64_t> column_shape; | std::vector<int64_t> column_shape; | ||||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); | |||||
| RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape)); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | ||||
| std::shared_ptr<Tensor> tensor; | std::shared_ptr<Tensor> tensor; | ||||
| RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); | RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); | ||||
| @@ -81,9 +79,8 @@ Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::v | |||||
| uint64_t n_bytes = 0, col_type_size = 1; | uint64_t n_bytes = 0, col_type_size = 1; | ||||
| mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; | ||||
| std::vector<int64_t> column_shape; | std::vector<int64_t> column_shape; | ||||
| MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); | |||||
| RETURN_IF_NOT_OK(shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, | |||||
| &col_type_size, &column_shape)); | |||||
| if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]); | ||||
| @@ -94,10 +94,9 @@ Status GraphLoader::InitAndLoad() { | |||||
| TaskGroup vg; | TaskGroup vg; | ||||
| shard_reader_ = std::make_unique<ShardReader>(); | shard_reader_ = std::make_unique<ShardReader>(); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, | |||||
| "Fail to open" + mr_path_); | |||||
| RETURN_IF_NOT_OK(shard_reader_->Open({mr_path_}, true, num_workers_)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); | CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); | |||||
| RETURN_IF_NOT_OK(shard_reader_->Launch(true)); | |||||
| graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); | graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); | ||||
| mindrecord::json schema = graph_impl_->data_schema_["schema"]; | mindrecord::json schema = graph_impl_->data_schema_["schema"]; | ||||
| @@ -116,8 +115,7 @@ Status GraphLoader::InitAndLoad() { | |||||
| if (graph_impl_->server_mode_) { | if (graph_impl_->server_mode_) { | ||||
| #if !defined(_WIN32) && !defined(_WIN64) | #if !defined(_WIN32) && !defined(_WIN64) | ||||
| int64_t total_blob_size = 0; | int64_t total_blob_size = 0; | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS, | |||||
| "failed to get total blob size"); | |||||
| RETURN_IF_NOT_OK(shard_reader_->GetTotalBlobSize(&total_blob_size)); | |||||
| graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_); | graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_); | ||||
| RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); | RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); | ||||
| #endif | #endif | ||||
| @@ -1,83 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "minddata/mindrecord/include/shard_error.h" | |||||
| namespace mindspore { | |||||
| namespace mindrecord { | |||||
| static const std::map<MSRStatus, std::string> kErrnoToMessage = { | |||||
| {FAILED, "operator failed"}, | |||||
| {SUCCESS, "operator success"}, | |||||
| {OPEN_FILE_FAILED, "open file failed"}, | |||||
| {CLOSE_FILE_FAILED, "close file failed"}, | |||||
| {WRITE_METADATA_FAILED, "write metadata failed"}, | |||||
| {WRITE_RAWDATA_FAILED, "write rawdata failed"}, | |||||
| {GET_SCHEMA_FAILED, "get schema failed"}, | |||||
| {ILLEGAL_RAWDATA, "illegal raw data"}, | |||||
| {PYTHON_TO_JSON_FAILED, "pybind: python object to json failed"}, | |||||
| {DIR_CREATE_FAILED, "directory create failed"}, | |||||
| {OPEN_DIR_FAILED, "open directory failed"}, | |||||
| {INVALID_STATISTICS, "invalid statistics object"}, | |||||
| {OPEN_DATABASE_FAILED, "open database failed"}, | |||||
| {CLOSE_DATABASE_FAILED, "close database failed"}, | |||||
| {DATABASE_OPERATE_FAILED, "database operate failed"}, | |||||
| {BUILD_SCHEMA_FAILED, "build schema failed"}, | |||||
| {DIVISOR_IS_ILLEGAL, "divisor is illegal"}, | |||||
| {INVALID_FILE_PATH, "file path is invalid"}, | |||||
| {SECURE_FUNC_FAILED, "secure function failed"}, | |||||
| {ALLOCATE_MEM_FAILED, "allocate memory failed"}, | |||||
| {ILLEGAL_FIELD_NAME, "illegal field name"}, | |||||
| {ILLEGAL_FIELD_TYPE, "illegal field type"}, | |||||
| {SET_METADATA_FAILED, "set metadata failed"}, | |||||
| {ILLEGAL_SCHEMA_DEFINITION, "illegal schema definition"}, | |||||
| {ILLEGAL_COLUMN_LIST, "illegal column list"}, | |||||
| {SQL_ERROR, "sql error"}, | |||||
| {ILLEGAL_SHARD_COUNT, "illegal shard count"}, | |||||
| {ILLEGAL_SCHEMA_COUNT, "illegal schema count"}, | |||||
| {VERSION_ERROR, "data version is not matched"}, | |||||
| {ADD_SCHEMA_FAILED, "add schema failed"}, | |||||
| {ILLEGAL_Header_SIZE, "illegal header size"}, | |||||
| {ILLEGAL_Page_SIZE, "illegal page size"}, | |||||
| {ILLEGAL_SIZE_VALUE, "illegal size value"}, | |||||
| {INDEX_FIELD_ERROR, "add index fields failed"}, | |||||
| {GET_CANDIDATE_CATEGORYFIELDS_FAILED, "get candidate category fields failed"}, | |||||
| {GET_CATEGORY_INFO_FAILED, "get category information failed"}, | |||||
| {ILLEGAL_CATEGORY_ID, "illegal category id"}, | |||||
| {ILLEGAL_ROWNUMBER_OF_PAGE, "illegal row number of page"}, | |||||
| {ILLEGAL_SCHEMA_ID, "illegal schema id"}, | |||||
| {DESERIALIZE_SCHEMA_FAILED, "deserialize schema failed"}, | |||||
| {DESERIALIZE_STATISTICS_FAILED, "deserialize statistics failed"}, | |||||
| {ILLEGAL_DB_FILE, "illegal db file"}, | |||||
| {OVERWRITE_DB_FILE, "overwrite db file"}, | |||||
| {OVERWRITE_MINDRECORD_FILE, "overwrite mindrecord file"}, | |||||
| {ILLEGAL_MINDRECORD_FILE, "illegal mindrecord file"}, | |||||
| {PARSE_JSON_FAILED, "parse json failed"}, | |||||
| {ILLEGAL_PARAMETERS, "illegal parameters"}, | |||||
| {GET_PAGE_BY_GROUP_ID_FAILED, "get page by group id failed"}, | |||||
| {GET_SYSTEM_STATE_FAILED, "get system state failed"}, | |||||
| {IO_FAILED, "io operate failed"}, | |||||
| {MATCH_HEADER_FAILED, "match header failed"}}; | |||||
| std::string ErrnoToMessage(MSRStatus status) { | |||||
| auto iter = kErrnoToMessage.find(status); | |||||
| if (iter != kErrnoToMessage.end()) { | |||||
| return kErrnoToMessage.at(status); | |||||
| } else { | |||||
| return "invalid error no"; | |||||
| } | |||||
| } | |||||
| } // namespace mindrecord | |||||
| } // namespace mindspore | |||||
| @@ -36,20 +36,42 @@ using mindspore::MsLogLevel::ERROR; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| #define THROW_IF_ERROR(s) \ | |||||
| do { \ | |||||
| Status rc = std::move(s); \ | |||||
| if (rc.IsError()) throw std::runtime_error(rc.ToString()); \ | |||||
| } while (false) | |||||
| void BindSchema(py::module *m) { | void BindSchema(py::module *m) { | ||||
| (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local()) | (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local()) | ||||
| .def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build) | |||||
| .def_static("build", | |||||
| [](const std::string &desc, const pybind11::handle &schema) { | |||||
| json schema_json = nlohmann::detail::ToJsonImpl(schema); | |||||
| return Schema::Build(std::move(desc), schema_json); | |||||
| }) | |||||
| .def("get_desc", &Schema::GetDesc) | .def("get_desc", &Schema::GetDesc) | ||||
| .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) | |||||
| .def("get_schema_content", | |||||
| [](Schema &s) { | |||||
| json schema_json = s.GetSchema(); | |||||
| return nlohmann::detail::FromJsonImpl(schema_json); | |||||
| }) | |||||
| .def("get_blob_fields", &Schema::GetBlobFields) | .def("get_blob_fields", &Schema::GetBlobFields) | ||||
| .def("get_schema_id", &Schema::GetSchemaID); | .def("get_schema_id", &Schema::GetSchemaID); | ||||
| } | } | ||||
| void BindStatistics(const py::module *m) { | void BindStatistics(const py::module *m) { | ||||
| (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local()) | (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local()) | ||||
| .def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build) | |||||
| .def_static("build", | |||||
| [](const std::string desc, const pybind11::handle &statistics) { | |||||
| json statistics_json = nlohmann::detail::ToJsonImpl(statistics); | |||||
| return Statistics::Build(std::move(desc), statistics_json); | |||||
| }) | |||||
| .def("get_desc", &Statistics::GetDesc) | .def("get_desc", &Statistics::GetDesc) | ||||
| .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) | |||||
| .def("get_statistics", | |||||
| [](Statistics &s) { | |||||
| json statistics_json = s.GetStatistics(); | |||||
| return nlohmann::detail::FromJsonImpl(statistics_json); | |||||
| }) | |||||
| .def("get_statistics_id", &Statistics::GetStatisticsID); | .def("get_statistics_id", &Statistics::GetStatisticsID); | ||||
| } | } | ||||
| @@ -59,70 +81,179 @@ void BindShardHeader(const py::module *m) { | |||||
| .def("add_schema", &ShardHeader::AddSchema) | .def("add_schema", &ShardHeader::AddSchema) | ||||
| .def("add_statistics", &ShardHeader::AddStatistic) | .def("add_statistics", &ShardHeader::AddStatistic) | ||||
| .def("add_index_fields", | .def("add_index_fields", | ||||
| (MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields) | |||||
| [](ShardHeader &s, const std::vector<std::string> &fields) { | |||||
| THROW_IF_ERROR(s.AddIndexFields(fields)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("get_meta", &ShardHeader::GetSchemas) | .def("get_meta", &ShardHeader::GetSchemas) | ||||
| .def("get_statistics", &ShardHeader::GetStatistics) | .def("get_statistics", &ShardHeader::GetStatistics) | ||||
| .def("get_fields", &ShardHeader::GetFields) | .def("get_fields", &ShardHeader::GetFields) | ||||
| .def("get_schema_by_id", &ShardHeader::GetSchemaByID) | |||||
| .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); | |||||
| .def("get_schema_by_id", | |||||
| [](ShardHeader &s, int64_t schema_id) { | |||||
| std::shared_ptr<Schema> schema_ptr; | |||||
| THROW_IF_ERROR(s.GetSchemaByID(schema_id, &schema_ptr)); | |||||
| return schema_ptr; | |||||
| }) | |||||
| .def("get_statistic_by_id", [](ShardHeader &s, int64_t statistic_id) { | |||||
| std::shared_ptr<Statistics> statistics_ptr; | |||||
| THROW_IF_ERROR(s.GetStatisticByID(statistic_id, &statistics_ptr)); | |||||
| return statistics_ptr; | |||||
| }); | |||||
| } | } | ||||
| void BindShardWriter(py::module *m) { | void BindShardWriter(py::module *m) { | ||||
| (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local()) | (void)py::class_<ShardWriter>(*m, "ShardWriter", py::module_local()) | ||||
| .def(py::init<>()) | .def(py::init<>()) | ||||
| .def("open", &ShardWriter::Open) | |||||
| .def("open_for_append", &ShardWriter::OpenForAppend) | |||||
| .def("set_header_size", &ShardWriter::SetHeaderSize) | |||||
| .def("set_page_size", &ShardWriter::SetPageSize) | |||||
| .def("set_shard_header", &ShardWriter::SetShardHeader) | |||||
| .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, | |||||
| vector<vector<uint8_t>> &, bool, bool)) & | |||||
| ShardWriter::WriteRawData) | |||||
| .def("commit", &ShardWriter::Commit); | |||||
| .def("open", | |||||
| [](ShardWriter &s, const std::vector<std::string> &paths, bool append) { | |||||
| THROW_IF_ERROR(s.Open(paths, append)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("open_for_append", | |||||
| [](ShardWriter &s, const std::string &path) { | |||||
| THROW_IF_ERROR(s.OpenForAppend(path)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("set_header_size", | |||||
| [](ShardWriter &s, const uint64_t &header_size) { | |||||
| THROW_IF_ERROR(s.SetHeaderSize(header_size)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("set_page_size", | |||||
| [](ShardWriter &s, const uint64_t &page_size) { | |||||
| THROW_IF_ERROR(s.SetPageSize(page_size)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("set_shard_header", | |||||
| [](ShardWriter &s, std::shared_ptr<ShardHeader> header_data) { | |||||
| THROW_IF_ERROR(s.SetShardHeader(header_data)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("write_raw_data", | |||||
| [](ShardWriter &s, std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||||
| bool sign, bool parallel_writer) { | |||||
| std::map<uint64_t, std::vector<json>> raw_data_json; | |||||
| (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), | |||||
| [](const std::pair<uint64_t, std::vector<py::handle>> &p) { | |||||
| auto &py_raw_data = p.second; | |||||
| std::vector<json> json_raw_data; | |||||
| (void)std::transform( | |||||
| py_raw_data.begin(), py_raw_data.end(), std::back_inserter(json_raw_data), | |||||
| [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); | |||||
| return std::make_pair(p.first, std::move(json_raw_data)); | |||||
| }); | |||||
| THROW_IF_ERROR(s.WriteRawData(raw_data_json, blob_data, sign, parallel_writer)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("commit", [](ShardWriter &s) { | |||||
| THROW_IF_ERROR(s.Commit()); | |||||
| return SUCCESS; | |||||
| }); | |||||
| } | } | ||||
| void BindShardReader(const py::module *m) { | void BindShardReader(const py::module *m) { | ||||
| (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local()) | (void)py::class_<ShardReader, std::shared_ptr<ShardReader>>(*m, "ShardReader", py::module_local()) | ||||
| .def(py::init<>()) | .def(py::init<>()) | ||||
| .def("open", (MSRStatus(ShardReader::*)(const std::vector<std::string> &, bool, const int &, | |||||
| const std::vector<std::string> &, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &)) & | |||||
| ShardReader::OpenPy) | |||||
| .def("launch", &ShardReader::Launch) | |||||
| .def("open", | |||||
| [](ShardReader &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer, | |||||
| const std::vector<std::string> &selected_columns, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators) { | |||||
| THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("launch", | |||||
| [](ShardReader &s) { | |||||
| THROW_IF_ERROR(s.Launch(false)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("get_header", &ShardReader::GetShardHeader) | .def("get_header", &ShardReader::GetShardHeader) | ||||
| .def("get_blob_fields", &ShardReader::GetBlobFields) | .def("get_blob_fields", &ShardReader::GetBlobFields) | ||||
| .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) & | |||||
| ShardReader::GetNextPy) | |||||
| .def("get_next", | |||||
| [](ShardReader &s) { | |||||
| auto data = s.GetNext(); | |||||
| vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> res; | |||||
| std::transform(data.begin(), data.end(), std::back_inserter(res), | |||||
| [&s](const std::tuple<std::vector<uint8_t>, json> &item) { | |||||
| auto &j = std::get<1>(item); | |||||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||||
| auto blob_data_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| (void)s.UnCompressBlob(std::get<0>(item), &blob_data_ptr); | |||||
| return std::make_tuple(*blob_data_ptr, std::move(obj)); | |||||
| }); | |||||
| return res; | |||||
| }) | |||||
| .def("close", &ShardReader::Close); | .def("close", &ShardReader::Close); | ||||
| } | } | ||||
| void BindShardIndexGenerator(const py::module *m) { | void BindShardIndexGenerator(const py::module *m) { | ||||
| (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local()) | (void)py::class_<ShardIndexGenerator>(*m, "ShardIndexGenerator", py::module_local()) | ||||
| .def(py::init<const std::string &, bool>()) | .def(py::init<const std::string &, bool>()) | ||||
| .def("build", &ShardIndexGenerator::Build) | |||||
| .def("write_to_db", &ShardIndexGenerator::WriteToDatabase); | |||||
| .def("build", | |||||
| [](ShardIndexGenerator &s) { | |||||
| THROW_IF_ERROR(s.Build()); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("write_to_db", [](ShardIndexGenerator &s) { | |||||
| THROW_IF_ERROR(s.WriteToDatabase()); | |||||
| return SUCCESS; | |||||
| }); | |||||
| } | } | ||||
| void BindShardSegment(py::module *m) { | void BindShardSegment(py::module *m) { | ||||
| (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local()) | (void)py::class_<ShardSegment>(*m, "ShardSegment", py::module_local()) | ||||
| .def(py::init<>()) | .def(py::init<>()) | ||||
| .def("open", (MSRStatus(ShardSegment::*)(const std::vector<std::string> &, bool, const int &, | |||||
| const std::vector<std::string> &, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &)) & | |||||
| ShardSegment::OpenPy) | |||||
| .def("open", | |||||
| [](ShardSegment &s, const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer, | |||||
| const std::vector<std::string> &selected_columns, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators) { | |||||
| THROW_IF_ERROR(s.Open(file_paths, load_dataset, n_consumer, selected_columns, operators)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("get_category_fields", | .def("get_category_fields", | ||||
| (std::pair<MSRStatus, vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetCategoryFields) | |||||
| .def("set_category_field", (MSRStatus(ShardSegment::*)(std::string)) & ShardSegment::SetCategoryField) | |||||
| .def("read_category_info", (std::pair<MSRStatus, std::string>(ShardSegment::*)()) & ShardSegment::ReadCategoryInfo) | |||||
| .def("read_at_page_by_id", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>( | |||||
| ShardSegment::*)(int64_t, int64_t, int64_t)) & | |||||
| ShardSegment::ReadAtPageByIdPy) | |||||
| .def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>( | |||||
| ShardSegment::*)(std::string, int64_t, int64_t)) & | |||||
| ShardSegment::ReadAtPageByNamePy) | |||||
| [](ShardSegment &s) { | |||||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||||
| THROW_IF_ERROR(s.GetCategoryFields(&fields_ptr)); | |||||
| return *fields_ptr; | |||||
| }) | |||||
| .def("set_category_field", | |||||
| [](ShardSegment &s, const std::string &category_field) { | |||||
| THROW_IF_ERROR(s.SetCategoryField(category_field)); | |||||
| return SUCCESS; | |||||
| }) | |||||
| .def("read_category_info", | |||||
| [](ShardSegment &s) { | |||||
| std::shared_ptr<std::string> category_ptr; | |||||
| THROW_IF_ERROR(s.ReadCategoryInfo(&category_ptr)); | |||||
| return *category_ptr; | |||||
| }) | |||||
| .def("read_at_page_by_id", | |||||
| [](ShardSegment &s, int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { | |||||
| auto pages_load_ptr = std::make_shared<PAGES_LOAD>(); | |||||
| auto pages_ptr = std::make_shared<PAGES>(); | |||||
| THROW_IF_ERROR(s.ReadAllAtPageById(category_id, page_no, n_rows_of_page, &pages_ptr)); | |||||
| (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr), | |||||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||||
| auto &j = std::get<1>(item); | |||||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||||
| }); | |||||
| return *pages_load_ptr; | |||||
| }) | |||||
| .def("read_at_page_by_name", | |||||
| [](ShardSegment &s, std::string category_name, int64_t page_no, int64_t n_rows_of_page) { | |||||
| auto pages_load_ptr = std::make_shared<PAGES_LOAD>(); | |||||
| auto pages_ptr = std::make_shared<PAGES>(); | |||||
| THROW_IF_ERROR(s.ReadAllAtPageByName(category_name, page_no, n_rows_of_page, &pages_ptr)); | |||||
| (void)std::transform(pages_ptr->begin(), pages_ptr->end(), std::back_inserter(*pages_load_ptr), | |||||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||||
| auto &j = std::get<1>(item); | |||||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||||
| }); | |||||
| return *pages_load_ptr; | |||||
| }) | |||||
| .def("get_header", &ShardSegment::GetShardHeader) | .def("get_header", &ShardSegment::GetShardHeader) | ||||
| .def("get_blob_fields", | |||||
| (std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetBlobFields); | |||||
| .def("get_blob_fields", [](ShardSegment &s) { return s.GetBlobFields(); }); | |||||
| } | } | ||||
| void BindGlobalParams(py::module *m) { | void BindGlobalParams(py::module *m) { | ||||
| @@ -57,26 +57,24 @@ bool ValidateFieldName(const std::string &str) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| std::pair<MSRStatus, std::string> GetFileName(const std::string &path) { | |||||
| Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(fn_ptr); | |||||
| char real_path[PATH_MAX] = {0}; | char real_path[PATH_MAX] = {0}; | ||||
| char buf[PATH_MAX] = {0}; | char buf[PATH_MAX] = {0}; | ||||
| if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { | if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { | ||||
| MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; | |||||
| return {FAILED, ""}; | |||||
| RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path); | |||||
| } | } | ||||
| char tmp[PATH_MAX] = {0}; | char tmp[PATH_MAX] = {0}; | ||||
| #if defined(_WIN32) || defined(_WIN64) | #if defined(_WIN32) || defined(_WIN64) | ||||
| if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { | if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { | ||||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||||
| return {FAILED, ""}; | |||||
| RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf)); | |||||
| } | } | ||||
| if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { | if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { | ||||
| MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; | MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; | ||||
| } | } | ||||
| #else | #else | ||||
| if (realpath(dirname(&(buf[0])), tmp) == nullptr) { | if (realpath(dirname(&(buf[0])), tmp) == nullptr) { | ||||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||||
| return {FAILED, ""}; | |||||
| RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf); | |||||
| } | } | ||||
| if (realpath(common::SafeCStr(path), real_path) == nullptr) { | if (realpath(common::SafeCStr(path), real_path) == nullptr) { | ||||
| MS_LOG(DEBUG) << "Path: " << path << "check successfully"; | MS_LOG(DEBUG) << "Path: " << path << "check successfully"; | ||||
| @@ -87,32 +85,32 @@ std::pair<MSRStatus, std::string> GetFileName(const std::string &path) { | |||||
| size_t i = s.rfind(sep, s.length()); | size_t i = s.rfind(sep, s.length()); | ||||
| if (i != std::string::npos) { | if (i != std::string::npos) { | ||||
| if (i + 1 < s.size()) { | if (i + 1 < s.size()) { | ||||
| return {SUCCESS, s.substr(i + 1)}; | |||||
| *fn_ptr = std::make_shared<std::string>(s.substr(i + 1)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } | } | ||||
| return {SUCCESS, s}; | |||||
| *fn_ptr = std::make_shared<std::string>(s); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) { | |||||
| Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(pd_ptr); | |||||
| char real_path[PATH_MAX] = {0}; | char real_path[PATH_MAX] = {0}; | ||||
| char buf[PATH_MAX] = {0}; | char buf[PATH_MAX] = {0}; | ||||
| if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { | if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { | ||||
| MS_LOG(ERROR) << "Securec func [strncpy_s] failed, path: " << path; | |||||
| return {FAILED, ""}; | |||||
| RETURN_STATUS_UNEXPECTED("Securec func [strncpy_s] failed, path: " + path); | |||||
| } | } | ||||
| char tmp[PATH_MAX] = {0}; | char tmp[PATH_MAX] = {0}; | ||||
| #if defined(_WIN32) || defined(_WIN64) | #if defined(_WIN32) || defined(_WIN64) | ||||
| if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { | if (_fullpath(tmp, dirname(&(buf[0])), PATH_MAX) == nullptr) { | ||||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||||
| return {FAILED, ""}; | |||||
| RETURN_STATUS_UNEXPECTED("Invalid file path, path: " + std::string(buf)); | |||||
| } | } | ||||
| if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { | if (_fullpath(real_path, common::SafeCStr(path), PATH_MAX) == nullptr) { | ||||
| MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; | MS_LOG(DEBUG) << "Path: " << common::SafeCStr(path) << "check successfully"; | ||||
| } | } | ||||
| #else | #else | ||||
| if (realpath(dirname(&(buf[0])), tmp) == nullptr) { | if (realpath(dirname(&(buf[0])), tmp) == nullptr) { | ||||
| MS_LOG(ERROR) << "Invalid file path, path: " << buf; | |||||
| return {FAILED, ""}; | |||||
| RETURN_STATUS_UNEXPECTED(std::string("Invalid file path, path: ") + buf); | |||||
| } | } | ||||
| if (realpath(common::SafeCStr(path), real_path) == nullptr) { | if (realpath(common::SafeCStr(path), real_path) == nullptr) { | ||||
| MS_LOG(DEBUG) << "Path: " << path << "check successfully"; | MS_LOG(DEBUG) << "Path: " << path << "check successfully"; | ||||
| @@ -120,9 +118,11 @@ std::pair<MSRStatus, std::string> GetParentDir(const std::string &path) { | |||||
| #endif | #endif | ||||
| std::string s = real_path; | std::string s = real_path; | ||||
| if (s.rfind('/') + 1 <= s.size()) { | if (s.rfind('/') + 1 <= s.size()) { | ||||
| return {SUCCESS, s.substr(0, s.rfind('/') + 1)}; | |||||
| *pd_ptr = std::make_shared<std::string>(s.substr(0, s.rfind('/') + 1)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| return {SUCCESS, "/"}; | |||||
| *pd_ptr = std::make_shared<std::string>("/"); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| bool CheckIsValidUtf8(const std::string &str) { | bool CheckIsValidUtf8(const std::string &str) { | ||||
| @@ -163,15 +163,16 @@ bool IsLegalFile(const std::string &path) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type) { | |||||
| Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(size_ptr); | |||||
| #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) | #if defined(_WIN32) || defined(_WIN64) || defined(__APPLE__) | ||||
| return {SUCCESS, 100}; | |||||
| *size_ptr = std::make_shared<uint64_t>(100); | |||||
| return Status::OK(); | |||||
| #else | #else | ||||
| uint64_t ll_count = 0; | uint64_t ll_count = 0; | ||||
| struct statfs disk_info; | struct statfs disk_info; | ||||
| if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { | if (statfs(common::SafeCStr(str_dir), &disk_info) == -1) { | ||||
| MS_LOG(ERROR) << "Get disk size error"; | |||||
| return {FAILED, 0}; | |||||
| RETURN_STATUS_UNEXPECTED("Get disk size error."); | |||||
| } | } | ||||
| switch (disk_type) { | switch (disk_type) { | ||||
| @@ -187,8 +188,8 @@ std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const Dis | |||||
| ll_count = 0; | ll_count = 0; | ||||
| break; | break; | ||||
| } | } | ||||
| return {SUCCESS, ll_count}; | |||||
| *size_ptr = std::make_shared<uint64_t>(ll_count); | |||||
| return Status::OK(); | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -201,17 +202,15 @@ uint32_t GetMaxThreadNum() { | |||||
| return thread_num; | return thread_num; | ||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses) { | |||||
| auto ret = GetParentDir(path); | |||||
| if (SUCCESS != ret.first) { | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| std::vector<std::string> abs_addresses; | |||||
| Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds) { | |||||
| RETURN_UNEXPECTED_IF_NULL(ds); | |||||
| std::shared_ptr<std::string> parent_dir; | |||||
| RETURN_IF_NOT_OK(GetParentDir(path, &parent_dir)); | |||||
| for (const auto &p : addresses) { | for (const auto &p : addresses) { | ||||
| std::string abs_path = ret.second + std::string(p); | |||||
| abs_addresses.emplace_back(abs_path); | |||||
| std::string abs_path = *parent_dir + std::string(p); | |||||
| (*ds)->emplace_back(abs_path); | |||||
| } | } | ||||
| return {SUCCESS, abs_addresses}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -33,6 +33,7 @@ | |||||
| #include <future> | #include <future> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | |||||
| #include <random> | #include <random> | ||||
| #include <set> | #include <set> | ||||
| #include <sstream> | #include <sstream> | ||||
| @@ -159,13 +160,15 @@ bool ValidateFieldName(const std::string &str); | |||||
| /// \brief get the filename by the path | /// \brief get the filename by the path | ||||
| /// \param s file path | /// \param s file path | ||||
| /// \return | |||||
| std::pair<MSRStatus, std::string> GetFileName(const std::string &s); | |||||
| /// \param fn_ptr shared ptr of file name | |||||
| /// \return Status | |||||
| Status GetFileName(const std::string &path, std::shared_ptr<std::string> *fn_ptr); | |||||
| /// \brief get parent dir | /// \brief get parent dir | ||||
| /// \param path file path | /// \param path file path | ||||
| /// \return parent path | |||||
| std::pair<MSRStatus, std::string> GetParentDir(const std::string &path); | |||||
| /// \param pd_ptr shared ptr of parent path | |||||
| /// \return Status | |||||
| Status GetParentDir(const std::string &path, std::shared_ptr<std::string> *pd_ptr); | |||||
| bool CheckIsValidUtf8(const std::string &str); | bool CheckIsValidUtf8(const std::string &str); | ||||
| @@ -179,8 +182,9 @@ enum DiskSizeType { kTotalSize = 0, kFreeSize }; | |||||
| /// \brief get the free space about the disk | /// \brief get the free space about the disk | ||||
| /// \param str_dir file path | /// \param str_dir file path | ||||
| /// \param disk_type: kTotalSize / kFreeSize | /// \param disk_type: kTotalSize / kFreeSize | ||||
| /// \return size in Megabytes | |||||
| std::pair<MSRStatus, uint64_t> GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type); | |||||
| /// \param size: shared ptr of size in Megabytes | |||||
| /// \return Status | |||||
| Status GetDiskSize(const std::string &str_dir, const DiskSizeType &disk_type, std::shared_ptr<uint64_t> *size); | |||||
| /// \brief get the max hardware concurrency | /// \brief get the max hardware concurrency | ||||
| /// \return max concurrency | /// \return max concurrency | ||||
| @@ -189,8 +193,9 @@ uint32_t GetMaxThreadNum(); | |||||
| /// \brief get absolute path of all mindrecord files | /// \brief get absolute path of all mindrecord files | ||||
| /// \param path path to one fo mindrecord files | /// \param path path to one fo mindrecord files | ||||
| /// \param addresses relative path of all mindrecord files | /// \param addresses relative path of all mindrecord files | ||||
| /// \return vector of absolute path | |||||
| std::pair<MSRStatus, std::vector<std::string>> GetDatasetFiles(const std::string &path, const json &addresses); | |||||
| /// \param ds shared ptr of vector of absolute path | |||||
| /// \return Status | |||||
| Status GetDatasetFiles(const std::string &path, const json &addresses, std::shared_ptr<std::vector<std::string>> *ds); | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -46,7 +46,7 @@ class __attribute__((visibility("default"))) ShardCategory : public ShardOperato | |||||
| bool GetReplacement() const { return replacement_; } | bool GetReplacement() const { return replacement_; } | ||||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||||
| Status Execute(ShardTaskList &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | ||||
| @@ -65,11 +65,11 @@ class __attribute__((visibility("default"))) ShardColumn { | |||||
| ~ShardColumn() = default; | ~ShardColumn() = default; | ||||
| /// \brief get column value by column name | /// \brief get column value by column name | ||||
| MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const json &columns_json, const unsigned char **data, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||||
| std::vector<int64_t> *column_shape); | |||||
| Status GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const json &columns_json, const unsigned char **data, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||||
| std::vector<int64_t> *column_shape); | |||||
| /// \brief compress blob | /// \brief compress blob | ||||
| std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size); | std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size); | ||||
| @@ -90,19 +90,18 @@ class __attribute__((visibility("default"))) ShardColumn { | |||||
| std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } | std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } | ||||
| /// \brief get column value from blob | /// \brief get column value from blob | ||||
| MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||||
| uint64_t *const n_bytes); | |||||
| Status GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||||
| uint64_t *const n_bytes); | |||||
| /// \brief get column type | /// \brief get column type | ||||
| std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name, | |||||
| ColumnDataType *column_data_type, | |||||
| uint64_t *column_data_type_size, | |||||
| std::vector<int64_t> *column_shape); | |||||
| Status GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type, | |||||
| uint64_t *column_data_type_size, std::vector<int64_t> *column_shape, | |||||
| ColumnCategory *column_category); | |||||
| /// \brief get column value from json | /// \brief get column value from json | ||||
| MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | |||||
| Status GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); | |||||
| private: | private: | ||||
| /// \brief initialization | /// \brief initialization | ||||
| @@ -110,15 +109,15 @@ class __attribute__((visibility("default"))) ShardColumn { | |||||
| /// \brief get float value from json | /// \brief get float value from json | ||||
| template <typename T> | template <typename T> | ||||
| MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | |||||
| Status GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); | |||||
| /// \brief get integer value from json | /// \brief get integer value from json | ||||
| template <typename T> | template <typename T> | ||||
| MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value); | |||||
| Status GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value); | |||||
| /// \brief get column offset address and size from blob | /// \brief get column offset address and size from blob | ||||
| MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||||
| uint64_t *num_bytes, uint64_t *shift_idx); | |||||
| Status GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||||
| uint64_t *num_bytes, uint64_t *shift_idx); | |||||
| /// \brief check if column name is available | /// \brief check if column name is available | ||||
| ColumnCategory CheckColumnName(const std::string &column_name); | ColumnCategory CheckColumnName(const std::string &column_name); | ||||
| @@ -128,8 +127,8 @@ class __attribute__((visibility("default"))) ShardColumn { | |||||
| /// \brief uncompress integer array column | /// \brief uncompress integer array column | ||||
| template <typename T> | template <typename T> | ||||
| static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); | |||||
| static Status UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx); | |||||
| /// \brief convert big-endian bytes to unsigned int | /// \brief convert big-endian bytes to unsigned int | ||||
| /// \param bytes_array bytes array | /// \param bytes_array bytes array | ||||
| @@ -39,7 +39,7 @@ class __attribute__((visibility("default"))) ShardDistributedSample : public Sha | |||||
| ~ShardDistributedSample() override{}; | ~ShardDistributedSample() override{}; | ||||
| MSRStatus PreExecute(ShardTaskList &tasks) override; | |||||
| Status PreExecute(ShardTaskList &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | ||||
| @@ -19,65 +19,55 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include "include/api/status.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| #define RETURN_IF_NOT_OK(_s) \ | |||||
| do { \ | |||||
| Status __rc = (_s); \ | |||||
| if (__rc.IsError()) { \ | |||||
| return __rc; \ | |||||
| } \ | |||||
| } while (false) | |||||
| #define RELEASE_AND_RETURN_IF_NOT_OK(_s, _db, _in) \ | |||||
| do { \ | |||||
| Status __rc = (_s); \ | |||||
| if (__rc.IsError()) { \ | |||||
| if ((_db) != nullptr) { \ | |||||
| sqlite3_close(_db); \ | |||||
| } \ | |||||
| (_in).close(); \ | |||||
| return __rc; \ | |||||
| } \ | |||||
| } while (false) | |||||
| #define CHECK_FAIL_RETURN_UNEXPECTED(_condition, _e) \ | |||||
| do { \ | |||||
| if (!(_condition)) { \ | |||||
| return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \ | |||||
| } \ | |||||
| } while (false) | |||||
| #define RETURN_UNEXPECTED_IF_NULL(_ptr) \ | |||||
| do { \ | |||||
| if ((_ptr) == nullptr) { \ | |||||
| std::string err_msg = "The pointer[" + std::string(#_ptr) + "] is null."; \ | |||||
| RETURN_STATUS_UNEXPECTED(err_msg); \ | |||||
| } \ | |||||
| } while (false) | |||||
| #define RETURN_STATUS_UNEXPECTED(_e) \ | |||||
| do { \ | |||||
| return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, _e); \ | |||||
| } while (false) | |||||
| enum MSRStatus { | enum MSRStatus { | ||||
| SUCCESS = 0, | SUCCESS = 0, | ||||
| FAILED = 1, | FAILED = 1, | ||||
| OPEN_FILE_FAILED, | |||||
| CLOSE_FILE_FAILED, | |||||
| WRITE_METADATA_FAILED, | |||||
| WRITE_RAWDATA_FAILED, | |||||
| GET_SCHEMA_FAILED, | |||||
| ILLEGAL_RAWDATA, | |||||
| PYTHON_TO_JSON_FAILED, | |||||
| DIR_CREATE_FAILED, | |||||
| OPEN_DIR_FAILED, | |||||
| INVALID_STATISTICS, | |||||
| OPEN_DATABASE_FAILED, | |||||
| CLOSE_DATABASE_FAILED, | |||||
| DATABASE_OPERATE_FAILED, | |||||
| BUILD_SCHEMA_FAILED, | |||||
| DIVISOR_IS_ILLEGAL, | |||||
| INVALID_FILE_PATH, | |||||
| SECURE_FUNC_FAILED, | |||||
| ALLOCATE_MEM_FAILED, | |||||
| ILLEGAL_FIELD_NAME, | |||||
| ILLEGAL_FIELD_TYPE, | |||||
| SET_METADATA_FAILED, | |||||
| ILLEGAL_SCHEMA_DEFINITION, | |||||
| ILLEGAL_COLUMN_LIST, | |||||
| SQL_ERROR, | |||||
| ILLEGAL_SHARD_COUNT, | |||||
| ILLEGAL_SCHEMA_COUNT, | |||||
| VERSION_ERROR, | |||||
| ADD_SCHEMA_FAILED, | |||||
| ILLEGAL_Header_SIZE, | |||||
| ILLEGAL_Page_SIZE, | |||||
| ILLEGAL_SIZE_VALUE, | |||||
| INDEX_FIELD_ERROR, | |||||
| GET_CANDIDATE_CATEGORYFIELDS_FAILED, | |||||
| GET_CATEGORY_INFO_FAILED, | |||||
| ILLEGAL_CATEGORY_ID, | |||||
| ILLEGAL_ROWNUMBER_OF_PAGE, | |||||
| ILLEGAL_SCHEMA_ID, | |||||
| DESERIALIZE_SCHEMA_FAILED, | |||||
| DESERIALIZE_STATISTICS_FAILED, | |||||
| ILLEGAL_DB_FILE, | |||||
| OVERWRITE_DB_FILE, | |||||
| OVERWRITE_MINDRECORD_FILE, | |||||
| ILLEGAL_MINDRECORD_FILE, | |||||
| PARSE_JSON_FAILED, | |||||
| ILLEGAL_PARAMETERS, | |||||
| GET_PAGE_BY_GROUP_ID_FAILED, | |||||
| GET_SYSTEM_STATE_FAILED, | |||||
| IO_FAILED, | |||||
| MATCH_HEADER_FAILED | |||||
| }; | }; | ||||
| // convert error no to string message | |||||
| std::string __attribute__((visibility("default"))) ErrnoToMessage(MSRStatus status); | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,9 +37,9 @@ class __attribute__((visibility("default"))) ShardHeader { | |||||
| ~ShardHeader() = default; | ~ShardHeader() = default; | ||||
| MSRStatus BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true); | |||||
| Status BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset = true); | |||||
| static std::pair<MSRStatus, json> BuildSingleHeader(const std::string &file_path); | |||||
| static Status BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr); | |||||
| /// \brief add the schema and save it | /// \brief add the schema and save it | ||||
| /// \param[in] schema the schema needs to be added | /// \param[in] schema the schema needs to be added | ||||
| /// \return the last schema's id | /// \return the last schema's id | ||||
| @@ -53,9 +53,9 @@ class __attribute__((visibility("default"))) ShardHeader { | |||||
| /// \brief create index and add fields which from schema for each schema | /// \brief create index and add fields which from schema for each schema | ||||
| /// \param[in] fields the index fields needs to be added | /// \param[in] fields the index fields needs to be added | ||||
| /// \return SUCCESS if add successfully, FAILED if not | /// \return SUCCESS if add successfully, FAILED if not | ||||
| MSRStatus AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields); | |||||
| Status AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields); | |||||
| MSRStatus AddIndexFields(const std::vector<std::string> &fields); | |||||
| Status AddIndexFields(const std::vector<std::string> &fields); | |||||
| /// \brief get the schema | /// \brief get the schema | ||||
| /// \return the schema | /// \return the schema | ||||
| @@ -79,9 +79,10 @@ class __attribute__((visibility("default"))) ShardHeader { | |||||
| std::shared_ptr<Index> GetIndex(); | std::shared_ptr<Index> GetIndex(); | ||||
| /// \brief get the schema by schemaid | /// \brief get the schema by schemaid | ||||
| /// \param[in] schemaId the id of schema needs to be got | |||||
| /// \return the schema obtained by schemaId | |||||
| std::pair<std::shared_ptr<Schema>, MSRStatus> GetSchemaByID(int64_t schema_id); | |||||
| /// \param[in] schema_id the id of schema needs to be got | |||||
| /// \param[in] schema_ptr the schema obtained by schemaId | |||||
| /// \return Status | |||||
| Status GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr); | |||||
| /// \brief get the filepath to shard by shardID | /// \brief get the filepath to shard by shardID | ||||
| /// \param[in] shardID the id of shard which filepath needs to be obtained | /// \param[in] shardID the id of shard which filepath needs to be obtained | ||||
| @@ -89,25 +90,26 @@ class __attribute__((visibility("default"))) ShardHeader { | |||||
| std::string GetShardAddressByID(int64_t shard_id); | std::string GetShardAddressByID(int64_t shard_id); | ||||
| /// \brief get the statistic by statistic id | /// \brief get the statistic by statistic id | ||||
| /// \param[in] statisticId the id of statistic needs to be get | |||||
| /// \return the statistics obtained by statistic id | |||||
| std::pair<std::shared_ptr<Statistics>, MSRStatus> GetStatisticByID(int64_t statistic_id); | |||||
| /// \param[in] statistic_id the id of statistic needs to be get | |||||
| /// \param[in] statistics_ptr the statistics obtained by statistic id | |||||
| /// \return Status | |||||
| Status GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr); | |||||
| MSRStatus InitByFiles(const std::vector<std::string> &file_paths); | |||||
| Status InitByFiles(const std::vector<std::string> &file_paths); | |||||
| void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } | void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } | ||||
| std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id); | |||||
| Status GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr); | |||||
| MSRStatus SetPage(const std::shared_ptr<Page> &new_page); | |||||
| Status SetPage(const std::shared_ptr<Page> &new_page); | |||||
| MSRStatus AddPage(const std::shared_ptr<Page> &new_page); | |||||
| Status AddPage(const std::shared_ptr<Page> &new_page); | |||||
| int64_t GetLastPageId(const int &shard_id); | int64_t GetLastPageId(const int &shard_id); | ||||
| int GetLastPageIdByType(const int &shard_id, const std::string &page_type); | int GetLastPageIdByType(const int &shard_id, const std::string &page_type); | ||||
| const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id); | |||||
| Status GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr); | |||||
| std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } | std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } | ||||
| @@ -129,43 +131,41 @@ class __attribute__((visibility("default"))) ShardHeader { | |||||
| std::vector<std::string> SerializeHeader(); | std::vector<std::string> SerializeHeader(); | ||||
| MSRStatus PagesToFile(const std::string dump_file_name); | |||||
| Status PagesToFile(const std::string dump_file_name); | |||||
| MSRStatus FileToPages(const std::string dump_file_name); | |||||
| Status FileToPages(const std::string dump_file_name); | |||||
| static MSRStatus Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||||
| uint64_t &schema_id); | |||||
| static Status Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||||
| uint64_t &schema_id); | |||||
| private: | private: | ||||
| MSRStatus InitializeHeader(const std::vector<json> &headers, bool load_dataset); | |||||
| Status InitializeHeader(const std::vector<json> &headers, bool load_dataset); | |||||
| /// \brief get the headers from all the shard data | /// \brief get the headers from all the shard data | ||||
| /// \param[in] real_addresses the real paths of the shard data | /// \param[in] real_addresses the real paths of the shard data | ||||
| /// \param[out] headers the headers read from the shard data | /// \param[out] headers the headers read from the shard data | ||||
| /// \return the status of the operation | /// \return the status of the operation | ||||
| MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); | |||||
| Status GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); | |||||
| MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); | |||||
| Status ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); | |||||
| /// \brief check the binary file status | /// \brief check the binary file status | ||||
| static MSRStatus CheckFileStatus(const std::string &path); | |||||
| static Status CheckFileStatus(const std::string &path); | |||||
| static std::pair<MSRStatus, json> ValidateHeader(const std::string &path); | |||||
| void ParseHeader(const json &header); | |||||
| static Status ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr); | |||||
| void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses); | void GetHeadersOneTask(int start, int end, std::vector<json> &headers, const vector<string> &realAddresses); | ||||
| MSRStatus ParseIndexFields(const json &index_fields); | |||||
| Status ParseIndexFields(const json &index_fields); | |||||
| MSRStatus CheckIndexField(const std::string &field, const json &schema); | |||||
| Status CheckIndexField(const std::string &field, const json &schema); | |||||
| MSRStatus ParsePage(const json &page, int shard_index, bool load_dataset); | |||||
| Status ParsePage(const json &page, int shard_index, bool load_dataset); | |||||
| MSRStatus ParseStatistics(const json &statistics); | |||||
| Status ParseStatistics(const json &statistics); | |||||
| MSRStatus ParseSchema(const json &schema); | |||||
| Status ParseSchema(const json &schema); | |||||
| void ParseShardAddress(const json &address); | void ParseShardAddress(const json &address); | ||||
| @@ -181,7 +181,7 @@ class __attribute__((visibility("default"))) ShardHeader { | |||||
| std::shared_ptr<Index> InitIndexPtr(); | std::shared_ptr<Index> InitIndexPtr(); | ||||
| MSRStatus GetAllSchemaID(std::set<uint64_t> &bucket_count); | |||||
| Status GetAllSchemaID(std::set<uint64_t> &bucket_count); | |||||
| uint32_t shard_count_; | uint32_t shard_count_; | ||||
| uint64_t header_size_; | uint64_t header_size_; | ||||
| @@ -30,23 +30,24 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| using INDEX_FIELDS = std::pair<MSRStatus, std::vector<std::tuple<std::string, std::string, std::string>>>; | |||||
| using ROW_DATA = std::pair<MSRStatus, std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>>; | |||||
| using INDEX_FIELDS = std::vector<std::tuple<std::string, std::string, std::string>>; | |||||
| using ROW_DATA = std::vector<std::vector<std::tuple<std::string, std::string, std::string>>>; | |||||
| class __attribute__((visibility("default"))) ShardIndexGenerator { | class __attribute__((visibility("default"))) ShardIndexGenerator { | ||||
| public: | public: | ||||
| explicit ShardIndexGenerator(const std::string &file_path, bool append = false); | explicit ShardIndexGenerator(const std::string &file_path, bool append = false); | ||||
| MSRStatus Build(); | |||||
| Status Build(); | |||||
| static std::pair<MSRStatus, std::string> GenerateFieldName(const std::pair<uint64_t, std::string> &field); | |||||
| static Status GenerateFieldName(const std::pair<uint64_t, std::string> &field, std::shared_ptr<std::string> *fn_ptr); | |||||
| ~ShardIndexGenerator() {} | ~ShardIndexGenerator() {} | ||||
| /// \brief fetch value in json by field name | /// \brief fetch value in json by field name | ||||
| /// \param[in] field the field name to look up | /// \param[in] field the field name to look up | ||||
| /// \param[in] input the json to fetch the value from | /// \param[in] input the json to fetch the value from | ||||
| /// \return pair<MSRStatus, value> | |||||
| std::pair<MSRStatus, std::string> GetValueByField(const string &field, json input); | |||||
| /// \param[out] value the fetched value | |||||
| /// \return Status | |||||
| Status GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value); | |||||
| /// \brief fetch field type in schema n by field path | /// \brief fetch field type in schema n by field path | ||||
| /// \param[in] field_path | /// \param[in] field_path | ||||
| @@ -55,55 +56,54 @@ class __attribute__((visibility("default"))) ShardIndexGenerator { | |||||
| static std::string TakeFieldType(const std::string &field_path, json schema); | static std::string TakeFieldType(const std::string &field_path, json schema); | ||||
| /// \brief create databases for indexes | /// \brief create databases for indexes | ||||
| MSRStatus WriteToDatabase(); | |||||
| Status WriteToDatabase(); | |||||
| static MSRStatus Finalize(const std::vector<std::string> file_names); | |||||
| static Status Finalize(const std::vector<std::string> file_names); | |||||
| private: | private: | ||||
| static int Callback(void *not_used, int argc, char **argv, char **az_col_name); | static int Callback(void *not_used, int argc, char **argv, char **az_col_name); | ||||
| static MSRStatus ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); | |||||
| static Status ExecuteSQL(const std::string &statement, sqlite3 *db, const string &success_msg = ""); | |||||
| static std::string ConvertJsonToSQL(const std::string &json); | static std::string ConvertJsonToSQL(const std::string &json); | ||||
| std::pair<MSRStatus, sqlite3 *> CreateDatabase(int shard_no); | |||||
| Status CreateDatabase(int shard_no, sqlite3 **db); | |||||
| std::pair<MSRStatus, std::vector<json>> GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in); | |||||
| Status GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in, | |||||
| std::shared_ptr<std::vector<json>> *detail_ptr); | |||||
| static std::pair<MSRStatus, std::string> GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields); | |||||
| static Status GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields, | |||||
| std::shared_ptr<std::string> *sql_ptr); | |||||
| std::pair<MSRStatus, sqlite3 *> CheckDatabase(const std::string &shard_address); | |||||
| Status CheckDatabase(const std::string &shard_address, sqlite3 **db); | |||||
| /// | /// | ||||
| /// \param shard_no | /// \param shard_no | ||||
| /// \param blob_id_to_page_id | /// \param blob_id_to_page_id | ||||
| /// \param raw_page_id | /// \param raw_page_id | ||||
| /// \param in | /// \param in | ||||
| /// \return field name, db type, field value | |||||
| ROW_DATA GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, | |||||
| std::fstream &in); | |||||
| /// \return Status; the rows (field name, db type, field value) are passed back through row_data_ptr | |||||
| Status GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, std::fstream &in, | |||||
| std::shared_ptr<ROW_DATA> *row_data_ptr); | |||||
| /// | /// | ||||
| /// \param db | /// \param db | ||||
| /// \param sql | /// \param sql | ||||
| /// \param data | /// \param data | ||||
| /// \return the status of the operation | /// \return the status of the operation | ||||
| MSRStatus BindParameterExecuteSQL( | |||||
| sqlite3 *db, const std::string &sql, | |||||
| const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data); | |||||
| Status BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data); | |||||
| INDEX_FIELDS GenerateIndexFields(const std::vector<json> &schema_detail); | |||||
| Status GenerateIndexFields(const std::vector<json> &schema_detail, std::shared_ptr<INDEX_FIELDS> *index_fields_ptr); | |||||
| MSRStatus ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db, | |||||
| const std::vector<int> &raw_page_ids, const std::map<int, int> &blob_id_to_page_id); | |||||
| Status ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids, | |||||
| const std::map<int, int> &blob_id_to_page_id); | |||||
| MSRStatus CreateShardNameTable(sqlite3 *db, const std::string &shard_name); | |||||
| Status CreateShardNameTable(sqlite3 *db, const std::string &shard_name); | |||||
| MSRStatus AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||||
| const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, | |||||
| std::fstream &in); | |||||
| Status AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||||
| const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, std::fstream &in); | |||||
| void AddIndexFieldByRawData(const std::vector<json> &schema_detail, | |||||
| std::vector<std::tuple<std::string, std::string, std::string>> &row_data); | |||||
| Status AddIndexFieldByRawData(const std::vector<json> &schema_detail, | |||||
| std::vector<std::tuple<std::string, std::string, std::string>> &row_data); | |||||
| void DatabaseWriter(); // worker thread | void DatabaseWriter(); // worker thread | ||||
| @@ -13,7 +13,6 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_OPERATOR_H_ | ||||
| @@ -28,33 +27,29 @@ class __attribute__((visibility("default"))) ShardOperator { | |||||
| public: | public: | ||||
| virtual ~ShardOperator() = default; | virtual ~ShardOperator() = default; | ||||
| MSRStatus operator()(ShardTaskList &tasks) { | |||||
| if (SUCCESS != this->PreExecute(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| if (SUCCESS != this->Execute(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| if (SUCCESS != this->SufExecute(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| Status operator()(ShardTaskList &tasks) { | |||||
| RETURN_IF_NOT_OK(this->PreExecute(tasks)); | |||||
| RETURN_IF_NOT_OK(this->Execute(tasks)); | |||||
| RETURN_IF_NOT_OK(this->SufExecute(tasks)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| virtual bool HasChildOp() { return child_op_ != nullptr; } | virtual bool HasChildOp() { return child_op_ != nullptr; } | ||||
| virtual MSRStatus SetChildOp(std::shared_ptr<ShardOperator> child_op) { | |||||
| if (child_op != nullptr) child_op_ = child_op; | |||||
| return SUCCESS; | |||||
| virtual Status SetChildOp(std::shared_ptr<ShardOperator> child_op) { | |||||
| if (child_op != nullptr) { | |||||
| child_op_ = child_op; | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } | virtual std::shared_ptr<ShardOperator> GetChildOp() { return child_op_; } | ||||
| virtual MSRStatus PreExecute(ShardTaskList &tasks) { return SUCCESS; } | |||||
| virtual Status PreExecute(ShardTaskList &tasks) { return Status::OK(); } | |||||
| virtual MSRStatus Execute(ShardTaskList &tasks) = 0; | |||||
| virtual Status Execute(ShardTaskList &tasks) = 0; | |||||
| virtual MSRStatus SufExecute(ShardTaskList &tasks) { return SUCCESS; } | |||||
| virtual Status SufExecute(ShardTaskList &tasks) { return Status::OK(); } | |||||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } | virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return 0; } | ||||
| @@ -72,9 +67,9 @@ class __attribute__((visibility("default"))) ShardOperator { | |||||
| std::shared_ptr<ShardOperator> child_op_ = nullptr; | std::shared_ptr<ShardOperator> child_op_ = nullptr; | ||||
| // shard_id : inc_count, where inc_count is the cumulative sample count up to and including that shard | // shard_id : inc_count, where inc_count is the cumulative sample count up to and including that shard | ||||
| // 0 : 15 - shard0 has 15 samples | |||||
| // 1 : 41 - shard1 has 26 samples | |||||
| // 2 : 58 - shard2 has 17 samples | |||||
| //   0 : 15 - shard0 has 15 samples | |||||
| //   1 : 41 - shard1 has 26 samples | |||||
| //   2 : 58 - shard2 has 17 samples | |||||
| std::vector<uint32_t> shard_sample_count_; | std::vector<uint32_t> shard_sample_count_; | ||||
| dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal; | dataset::ShuffleMode shuffle_mode_ = dataset::ShuffleMode::kGlobal; | ||||
| @@ -38,7 +38,7 @@ class __attribute__((visibility("default"))) ShardPkSample : public ShardCategor | |||||
| ~ShardPkSample() override{}; | ~ShardPkSample() override{}; | ||||
| MSRStatus SufExecute(ShardTaskList &tasks) override; | |||||
| Status SufExecute(ShardTaskList &tasks) override; | |||||
| int64_t GetNumSamples() const { return num_samples_; } | int64_t GetNumSamples() const { return num_samples_; } | ||||
| @@ -59,12 +59,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| using ROW_GROUPS = | |||||
| std::tuple<MSRStatus, std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>; | |||||
| using ROW_GROUP_BRIEF = | |||||
| std::tuple<MSRStatus, std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>; | |||||
| using TASK_RETURN_CONTENT = | |||||
| std::pair<MSRStatus, std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>>; | |||||
| using ROW_GROUPS = std::pair<std::vector<std::vector<std::vector<uint64_t>>>, std::vector<std::vector<json>>>; | |||||
| using ROW_GROUP_BRIEF = std::tuple<std::string, int, uint64_t, std::vector<std::vector<uint64_t>>, std::vector<json>>; | |||||
| using TASK_CONTENT = std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>>; | |||||
| const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode | const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode | ||||
| class API_PUBLIC ShardReader { | class API_PUBLIC ShardReader { | ||||
| @@ -82,21 +79,10 @@ class API_PUBLIC ShardReader { | |||||
| /// \param[in] num_padded the number of padded samples | /// \param[in] num_padded the number of padded samples | ||||
| /// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization | /// \param[in] lazy_load if the mindrecord dataset is too large, enable lazy load mode to speed up initialization | ||||
| /// \return the status of the operation | /// \return the status of the operation | ||||
| MSRStatus Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4, | |||||
| const std::vector<std::string> &selected_columns = {}, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0, | |||||
| bool lazy_load = false); | |||||
| /// \brief open files and initialize reader, python API | |||||
| /// \param[in] file_paths the path of ONE file, any file in dataset is fine or file list | |||||
| /// \param[in] load_dataset load dataset from single file or not | |||||
| /// \param[in] n_consumer number of threads when reading | |||||
| /// \param[in] selected_columns column list to be populated | |||||
| /// \param[in] operators operators applied to data, operator type is shuffle, sample or category | |||||
| /// \return MSRStatus the status of MSRStatus | |||||
| MSRStatus OpenPy(const std::vector<std::string> &file_paths, bool load_dataset, const int &n_consumer = 4, | |||||
| const std::vector<std::string> &selected_columns = {}, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators = {}); | |||||
| Status Open(const std::vector<std::string> &file_paths, bool load_dataset, int n_consumer = 4, | |||||
| const std::vector<std::string> &selected_columns = {}, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators = {}, const int num_padded = 0, | |||||
| bool lazy_load = false); | |||||
| /// \brief close reader | /// \brief close reader | ||||
| /// \return null | /// \return null | ||||
| @@ -104,16 +90,16 @@ class API_PUBLIC ShardReader { | |||||
| /// \brief read the file, get schema meta, statistics and index, single-thread mode | /// \brief read the file, get schema meta, statistics and index, single-thread mode | ||||
| /// \return the status of the operation | /// \return the status of the operation | ||||
| MSRStatus Open(); | |||||
| Status Open(); | |||||
| /// \brief read the file, get schema meta, statistics and index, multiple-thread mode | /// \brief read the file, get schema meta, statistics and index, multiple-thread mode | ||||
| /// \return the status of the operation | /// \return the status of the operation | ||||
| MSRStatus Open(int n_consumer); | |||||
| Status Open(int n_consumer); | |||||
| /// \brief launch threads to get batches | /// \brief launch threads to get batches | ||||
| /// \param[in] is_simple_reader trigger threads if false; do nothing if true | /// \param[in] is_simple_reader trigger threads if false; do nothing if true | ||||
| /// \return the status of the operation | /// \return the status of the operation | ||||
| MSRStatus Launch(bool is_simple_reader = false); | |||||
| Status Launch(bool is_simple_reader = false); | |||||
| /// \brief aim to get the meta data | /// \brief aim to get the meta data | ||||
| /// \return the metadata | /// \return the metadata | ||||
| @@ -133,8 +119,8 @@ class API_PUBLIC ShardReader { | |||||
| /// \param[in] op smart pointer referring to a ShardCategory or ShardSample object | /// \param[in] op smart pointer referring to a ShardCategory or ShardSample object | ||||
| /// \param[out] count the number of rows | /// \param[out] count the number of rows | ||||
| /// \return the status of the operation | /// \return the status of the operation | ||||
| MSRStatus CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset, | |||||
| const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded); | |||||
| Status CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset, | |||||
| const std::shared_ptr<ShardOperator> &op, int64_t *count, const int num_padded); | |||||
| /// \brief shuffle task with incremental seed | /// \brief shuffle task with incremental seed | ||||
| /// \return void | /// \return void | ||||
| @@ -162,8 +148,8 @@ class API_PUBLIC ShardReader { | |||||
| /// 3. Offset address of row group in file | /// 3. Offset address of row group in file | ||||
| /// 4. The list of image offset in page [startOffset, endOffset) | /// 4. The list of image offset in page [startOffset, endOffset) | ||||
| /// 5. The list of columns data | /// 5. The list of columns data | ||||
| ROW_GROUP_BRIEF ReadRowGroupBrief(int group_id, int shard_id, | |||||
| const std::vector<std::string> &columns = std::vector<std::string>()); | |||||
| Status ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns, | |||||
| std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr); | |||||
| /// \brief Read 1 row group data, excluding images, following an index field criteria | /// \brief Read 1 row group data, excluding images, following an index field criteria | ||||
| /// \param[in] groupID row group ID | /// \param[in] groupID row group ID | ||||
| @@ -176,8 +162,9 @@ class API_PUBLIC ShardReader { | |||||
| /// 3. Offset address of row group in file | /// 3. Offset address of row group in file | ||||
| /// 4. The list of image offset in page [startOffset, endOffset) | /// 4. The list of image offset in page [startOffset, endOffset) | ||||
| /// 5. The list of columns data | /// 5. The list of columns data | ||||
| ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, | |||||
| const std::vector<std::string> &columns = std::vector<std::string>()); | |||||
| Status ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, | |||||
| const std::vector<std::string> &columns, | |||||
| std::shared_ptr<ROW_GROUP_BRIEF> *row_group_brief_ptr); | |||||
| /// \brief return a batch, given that one is ready | /// \brief return a batch, given that one is ready | ||||
| /// \return a batch of images and image data | /// \return a batch of images and image data | ||||
| @@ -185,13 +172,7 @@ class API_PUBLIC ShardReader { | |||||
| /// \brief return a row by id | /// \brief return a row by id | ||||
| /// \return a batch of images and image data | /// \return a batch of images and image data | ||||
| std::pair<TaskType, std::vector<std::tuple<std::vector<uint8_t>, json>>> GetNextById(const int64_t &task_id, | |||||
| const int32_t &consumer_id); | |||||
| /// \brief return a batch, given that one is ready, python API | |||||
| /// \return a batch of images and image data | |||||
| std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy(); | |||||
| TASK_CONTENT GetNextById(const int64_t &task_id, const int32_t &consumer_id); | |||||
| /// \brief get blob field list | /// \brief get blob field list | ||||
| /// \return blob field list | /// \return blob field list | ||||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | ||||
| @@ -205,83 +186,86 @@ class API_PUBLIC ShardReader { | |||||
| void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } | void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } | ||||
| /// \brief get all classes | /// \brief get all classes | ||||
| MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr); | |||||
| /// \brief get the size of blob data | |||||
| MSRStatus GetTotalBlobSize(int64_t *total_blob_size); | |||||
| Status GetAllClasses(const std::string &category_field, std::shared_ptr<std::set<std::string>> category_ptr); | |||||
| /// \brief get a read-only ptr to the sampled ids for this epoch | /// \brief get a read-only ptr to the sampled ids for this epoch | ||||
| const std::vector<int> *GetSampleIds(); | const std::vector<int> *GetSampleIds(); | ||||
| /// \brief get the size of blob data | |||||
| Status GetTotalBlobSize(int64_t *total_blob_size); | |||||
| /// \brief extract uncompressed data based on column list | |||||
| Status UnCompressBlob(const std::vector<uint8_t> &raw_blob_data, | |||||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *blob_data_ptr); | |||||
| protected: | protected: | ||||
| /// \brief sqlite call back function | /// \brief sqlite call back function | ||||
| static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | ||||
| private: | private: | ||||
| /// \brief wrap up labels to json format | /// \brief wrap up labels to json format | ||||
| MSRStatus ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs, | |||||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, | |||||
| int shard_id, const std::vector<std::string> &columns, | |||||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||||
| Status ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, std::shared_ptr<std::fstream> fs, | |||||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, int shard_id, | |||||
| const std::vector<std::string> &columns, | |||||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||||
| /// \brief read all rows for specified columns | /// \brief read all rows for specified columns | ||||
| ROW_GROUPS ReadAllRowGroup(const std::vector<std::string> &columns); | |||||
| Status ReadAllRowGroup(const std::vector<std::string> &columns, std::shared_ptr<ROW_GROUPS> *row_group_ptr); | |||||
| /// \brief read row meta by shard_id and sample_id | /// \brief read row meta by shard_id and sample_id | ||||
| ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id, | |||||
| const uint32_t &sample_id); | |||||
| Status ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, const uint32_t &shard_id, | |||||
| const uint32_t &sample_id, std::shared_ptr<ROW_GROUPS> *row_group_ptr); | |||||
| /// \brief read all rows in one shard | /// \brief read all rows in one shard | ||||
| MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns, | |||||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, | |||||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||||
| Status ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns, | |||||
| std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, | |||||
| std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr); | |||||
| /// \brief initialize reader | /// \brief initialize reader | ||||
| MSRStatus Init(const std::vector<std::string> &file_paths, bool load_dataset); | |||||
| Status Init(const std::vector<std::string> &file_paths, bool load_dataset); | |||||
| /// \brief validate column list | /// \brief validate column list | ||||
| MSRStatus CheckColumnList(const std::vector<std::string> &selected_columns); | |||||
| Status CheckColumnList(const std::vector<std::string> &selected_columns); | |||||
| /// \brief populate one row by task list in row-reader mode | /// \brief populate one row by task list in row-reader mode | ||||
| MSRStatus ConsumerByRow(int consumer_id); | |||||
| void ConsumerByRow(int consumer_id); | |||||
| /// \brief get offset address of images within page | /// \brief get offset address of images within page | ||||
| std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id, | std::vector<std::vector<uint64_t>> GetImageOffset(int group_id, int shard_id, | ||||
| const std::pair<std::string, std::string> &criteria = {"", ""}); | const std::pair<std::string, std::string> &criteria = {"", ""}); | ||||
| /// \brief get page id by category | /// \brief get page id by category | ||||
| std::pair<MSRStatus, std::vector<uint64_t>> GetPagesByCategory(int shard_id, | |||||
| const std::pair<std::string, std::string> &criteria); | |||||
| Status GetPagesByCategory(int shard_id, const std::pair<std::string, std::string> &criteria, | |||||
| std::shared_ptr<std::vector<uint64_t>> *pages_ptr); | |||||
| /// \brief execute sqlite query with prepare statement | /// \brief execute sqlite query with prepare statement | ||||
| MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, | |||||
| std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr); | |||||
| Status QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, | |||||
| std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr); | |||||
| /// \brief verify the validity of dataset | /// \brief verify the validity of dataset | ||||
| MSRStatus VerifyDataset(sqlite3 **db, const string &file); | |||||
| Status VerifyDataset(sqlite3 **db, const string &file); | |||||
| /// \brief get column values | /// \brief get column values | ||||
| std::pair<MSRStatus, std::vector<json>> GetLabels(int group_id, int shard_id, const std::vector<std::string> &columns, | |||||
| const std::pair<std::string, std::string> &criteria = {"", ""}); | |||||
| Status GetLabels(int page_id, int shard_id, const std::vector<std::string> &columns, | |||||
| const std::pair<std::string, std::string> &criteria, std::shared_ptr<std::vector<json>> *labels_ptr); | |||||
| /// \brief get column values from raw data page | /// \brief get column values from raw data page | ||||
| std::pair<MSRStatus, std::vector<json>> GetLabelsFromPage(int group_id, int shard_id, | |||||
| const std::vector<std::string> &columns, | |||||
| const std::pair<std::string, std::string> &criteria = {"", | |||||
| ""}); | |||||
| Status GetLabelsFromPage(int page_id, int shard_id, const std::vector<std::string> &columns, | |||||
| const std::pair<std::string, std::string> &criteria, | |||||
| std::shared_ptr<std::vector<json>> *labels_ptr); | |||||
| /// \brief create category-applied task list | /// \brief create category-applied task list | ||||
| MSRStatus CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op); | |||||
| Status CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op); | |||||
| /// \brief create task list in row-reader mode | /// \brief create task list in row-reader mode | ||||
| MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||||
| Status CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||||
| /// \brief create task list in row-reader mode and lazy mode | /// \brief create task list in row-reader mode and lazy mode | ||||
| MSRStatus CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||||
| Status CreateLazyTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||||
| /// \brief create task list | /// \brief create task list | ||||
| MSRStatus CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||||
| Status CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||||
| /// \brief check if all specified columns are in index table | /// \brief check if all specified columns are in index table | ||||
| void CheckIfColumnInIndex(const std::vector<std::string> &columns); | void CheckIfColumnInIndex(const std::vector<std::string> &columns); | ||||
| @@ -290,11 +274,12 @@ class API_PUBLIC ShardReader { | |||||
| void FileStreamsOperator(); | void FileStreamsOperator(); | ||||
| /// \brief read one row by one task | /// \brief read one row by one task | ||||
| TASK_RETURN_CONTENT ConsumerOneTask(int task_id, uint32_t consumer_id); | |||||
| Status ConsumerOneTask(int task_id, uint32_t consumer_id, std::shared_ptr<TASK_CONTENT> *task_content_pt); | |||||
| /// \brief get labels from binary file | /// \brief get labels from binary file | ||||
| std::pair<MSRStatus, std::vector<json>> GetLabelsFromBinaryFile( | |||||
| int shard_id, const std::vector<std::string> &columns, const std::vector<std::vector<std::string>> &label_offsets); | |||||
| Status GetLabelsFromBinaryFile(int shard_id, const std::vector<std::string> &columns, | |||||
| const std::vector<std::vector<std::string>> &label_offsets, | |||||
| std::shared_ptr<std::vector<json>> *labels_ptr); | |||||
| /// \brief get classes in one shard | /// \brief get classes in one shard | ||||
| void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, | void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, | ||||
| @@ -304,11 +289,8 @@ class API_PUBLIC ShardReader { | |||||
| int64_t GetNumClasses(const std::string &category_field); | int64_t GetNumClasses(const std::string &category_field); | ||||
| /// \brief get meta of header | /// \brief get meta of header | ||||
| std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, | |||||
| std::shared_ptr<json> meta_data_ptr); | |||||
| /// \brief extract uncompressed data based on column list | |||||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data); | |||||
| Status GetMeta(const std::string &file_path, std::shared_ptr<json> meta_data_ptr, | |||||
| std::shared_ptr<std::vector<std::string>> *addresses_ptr); | |||||
| protected: | protected: | ||||
| uint64_t header_size_; // header size | uint64_t header_size_; // header size | ||||
| @@ -40,11 +40,11 @@ class __attribute__((visibility("default"))) ShardSample : public ShardOperator | |||||
| ~ShardSample() override{}; | ~ShardSample() override{}; | ||||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||||
| Status Execute(ShardTaskList &tasks) override; | |||||
| MSRStatus UpdateTasks(ShardTaskList &tasks, int taking); | |||||
| Status UpdateTasks(ShardTaskList &tasks, int taking); | |||||
| MSRStatus SufExecute(ShardTaskList &tasks) override; | |||||
| Status SufExecute(ShardTaskList &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | ||||
| @@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Schema { | |||||
| /// \param[in] schema the schema's json | /// \param[in] schema the schema's json | ||||
| static std::shared_ptr<Schema> Build(std::string desc, const json &schema); | static std::shared_ptr<Schema> Build(std::string desc, const json &schema); | ||||
| /// \brief obtain the json schema and its description for python | |||||
| /// \param[in] desc the description of the schema | |||||
| /// \param[in] schema the schema's json | |||||
| static std::shared_ptr<Schema> Build(std::string desc, pybind11::handle schema); | |||||
| /// \brief compare two schema to judge if they are equal | /// \brief compare two schema to judge if they are equal | ||||
| /// \param b another schema to be judged | /// \param b another schema to be judged | ||||
| /// \return true if they are equal, false if not | /// \return true if they are equal, false if not | ||||
| @@ -57,10 +52,6 @@ class __attribute__((visibility("default"))) Schema { | |||||
| /// \return the json format of the schema and its description | /// \return the json format of the schema and its description | ||||
| json GetSchema() const; | json GetSchema() const; | ||||
| /// \brief get the schema and its description for python method | |||||
| /// \return the python object of the schema and its description | |||||
| pybind11::object GetSchemaForPython() const; | |||||
| /// set the schema id | /// set the schema id | ||||
| /// \param[in] id the id need to be set | /// \param[in] id the id need to be set | ||||
| void SetSchemaID(int64_t id); | void SetSchemaID(int64_t id); | ||||
| @@ -17,6 +17,7 @@ | |||||
| #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ | #ifndef MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ | ||||
| #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ | #define MINDSPORE_CCSRC_MINDDATA_MINDRECORD_INCLUDE_SHARD_SEGMENT_H_ | ||||
| #include <memory> | |||||
| #include <string> | #include <string> | ||||
| #include <tuple> | #include <tuple> | ||||
| #include <utility> | #include <utility> | ||||
| @@ -25,6 +26,10 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| using CATEGORY_INFO = std::vector<std::tuple<int, std::string, int>>; | |||||
| using PAGES = std::vector<std::tuple<std::vector<uint8_t>, json>>; | |||||
| using PAGES_LOAD = std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>; | |||||
| class __attribute__((visibility("default"))) ShardSegment : public ShardReader { | class __attribute__((visibility("default"))) ShardSegment : public ShardReader { | ||||
| public: | public: | ||||
| ShardSegment(); | ShardSegment(); | ||||
| @@ -33,12 +38,12 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader { | |||||
| /// \brief Get candidate category fields | /// \brief Get candidate category fields | ||||
| /// \return a list of fields names which are the candidates of category | /// \return a list of fields names which are the candidates of category | ||||
| std::pair<MSRStatus, vector<std::string>> GetCategoryFields(); | |||||
| Status GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr); | |||||
| /// \brief Set category field | /// \brief Set category field | ||||
| /// \param[in] category_field category name | /// \param[in] category_field category name | ||||
| /// \return true if category name exists | /// \return true if category name exists | ||||
| MSRStatus SetCategoryField(std::string category_field); | |||||
| Status SetCategoryField(std::string category_field); | |||||
| /// \brief Thread-safe implementation of ReadCategoryInfo | /// \brief Thread-safe implementation of ReadCategoryInfo | ||||
| /// \return statistics data in json format with 2 fields: "key" and "categories". | /// \return statistics data in json format with 2 fields: "key" and "categories". | ||||
| @@ -50,47 +55,41 @@ class __attribute__((visibility("default"))) ShardSegment : public ShardReader { | |||||
| /// { "key": "label", | /// { "key": "label", | ||||
| /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, | /// "categories": [ { "count": 3, "id": 0, "name": "sport", }, | ||||
| /// { "count": 3, "id": 1, "name": "finance", } ] } | /// { "count": 3, "id": 1, "name": "finance", } ] } | ||||
| std::pair<MSRStatus, std::string> ReadCategoryInfo(); | |||||
| Status ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr); | |||||
| /// \brief Thread-safe implementation of ReadAtPageById | /// \brief Thread-safe implementation of ReadAtPageById | ||||
| /// \param[in] category_id category ID | /// \param[in] category_id category ID | ||||
| /// \param[in] page_no page number | /// \param[in] page_no page number | ||||
| /// \param[in] n_rows_of_page rows number in one page | /// \param[in] n_rows_of_page rows number in one page | ||||
| /// \return images array, image is a vector of uint8_t | /// \return images array, image is a vector of uint8_t | ||||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageById(int64_t category_id, int64_t page_no, | |||||
| int64_t n_rows_of_page); | |||||
| Status ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr); | |||||
| /// \brief Thread-safe implementation of ReadAtPageByName | /// \brief Thread-safe implementation of ReadAtPageByName | ||||
| /// \param[in] category_name category Name | /// \param[in] category_name category Name | ||||
| /// \param[in] page_no page number | /// \param[in] page_no page number | ||||
| /// \param[in] n_rows_of_page rows number in one page | /// \param[in] n_rows_of_page rows number in one page | ||||
| /// \return images array, image is a vector of uint8_t | /// \return images array, image is a vector of uint8_t | ||||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ReadAtPageByName(std::string category_name, int64_t page_no, | |||||
| int64_t n_rows_of_page); | |||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageById(int64_t category_id, | |||||
| int64_t page_no, | |||||
| int64_t n_rows_of_page); | |||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ReadAllAtPageByName( | |||||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page); | |||||
| Status ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr); | |||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByIdPy( | |||||
| int64_t category_id, int64_t page_no, int64_t n_rows_of_page); | |||||
| Status ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<PAGES> *pages_ptr); | |||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy( | |||||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page); | |||||
| Status ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<PAGES> *pages_ptr); | |||||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | ||||
| private: | private: | ||||
| std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo(); | |||||
| Status WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr); | |||||
| std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec); | std::string ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec); | ||||
| std::string CleanUp(std::string fieldName); | std::string CleanUp(std::string fieldName); | ||||
| std::pair<MSRStatus, std::vector<uint8_t>> PackImages(int group_id, int shard_id, std::vector<uint64_t> offset); | |||||
| Status PackImages(int group_id, int shard_id, std::vector<uint64_t> offset, | |||||
| std::shared_ptr<std::vector<uint8_t>> *images_ptr); | |||||
| std::vector<std::string> candidate_category_fields_; | std::vector<std::string> candidate_category_fields_; | ||||
| std::string current_category_field_; | std::string current_category_field_; | ||||
| @@ -33,7 +33,7 @@ class __attribute__((visibility("default"))) ShardSequentialSample : public Shar | |||||
| ~ShardSequentialSample() override{}; | ~ShardSequentialSample() override{}; | ||||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||||
| Status Execute(ShardTaskList &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | ||||
| @@ -31,19 +31,19 @@ class __attribute__((visibility("default"))) ShardShuffle : public ShardOperator | |||||
| ~ShardShuffle() override{}; | ~ShardShuffle() override{}; | ||||
| MSRStatus Execute(ShardTaskList &tasks) override; | |||||
| Status Execute(ShardTaskList &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | ||||
| private: | private: | ||||
| // Private helper function | // Private helper function | ||||
| MSRStatus CategoryShuffle(ShardTaskList &tasks); | |||||
| Status CategoryShuffle(ShardTaskList &tasks); | |||||
| // Keep the file sequence the same but shuffle the data within each file | // Keep the file sequence the same but shuffle the data within each file | ||||
| MSRStatus ShuffleInfile(ShardTaskList &tasks); | |||||
| Status ShuffleInfile(ShardTaskList &tasks); | |||||
| // Shuffle the file sequence but keep the order of data within each file | // Shuffle the file sequence but keep the order of data within each file | ||||
| MSRStatus ShuffleFiles(ShardTaskList &tasks); | |||||
| Status ShuffleFiles(ShardTaskList &tasks); | |||||
| uint32_t shuffle_seed_; | uint32_t shuffle_seed_; | ||||
| int64_t no_of_samples_; | int64_t no_of_samples_; | ||||
| @@ -39,11 +39,6 @@ class __attribute__((visibility("default"))) Statistics { | |||||
| /// \param[in] statistics the statistic needs to be saved | /// \param[in] statistics the statistic needs to be saved | ||||
| static std::shared_ptr<Statistics> Build(std::string desc, const json &statistics); | static std::shared_ptr<Statistics> Build(std::string desc, const json &statistics); | ||||
| /// \brief save the statistic from python and its description | |||||
| /// \param[in] desc the statistic's description | |||||
| /// \param[in] statistics the statistic needs to be saved | |||||
| static std::shared_ptr<Statistics> Build(std::string desc, pybind11::handle statistics); | |||||
| ~Statistics() = default; | ~Statistics() = default; | ||||
| /// \brief compare two statistics to judge if they are equal | /// \brief compare two statistics to judge if they are equal | ||||
| @@ -59,10 +54,6 @@ class __attribute__((visibility("default"))) Statistics { | |||||
| /// \return json format of the statistic | /// \return json format of the statistic | ||||
| json GetStatistics() const; | json GetStatistics() const; | ||||
| /// \brief get the statistic for python | |||||
| /// \return the python object of statistics | |||||
| pybind11::object GetStatisticsForPython() const; | |||||
| /// \brief decode the bson statistics to json | /// \brief decode the bson statistics to json | ||||
| /// \param[in] encodedStatistics the bson type of statistics | /// \param[in] encodedStatistics the bson type of statistics | ||||
| /// \return json type of statistic | /// \return json type of statistic | ||||
| @@ -55,69 +55,60 @@ class __attribute__((visibility("default"))) ShardWriter { | |||||
| /// \brief Open file at the beginning | /// \brief Open file at the beginning | ||||
| /// \param[in] paths the file names list | /// \param[in] paths the file names list | ||||
| /// \param[in] append new data at the end of file if true, otherwise overwrite file | /// \param[in] append new data at the end of file if true, otherwise overwrite file | ||||
| /// \return MSRStatus the status of MSRStatus | |||||
| MSRStatus Open(const std::vector<std::string> &paths, bool append = false); | |||||
| /// \return Status | |||||
| Status Open(const std::vector<std::string> &paths, bool append = false); | |||||
| /// \brief Open file at the ending | /// \brief Open file at the ending | ||||
| /// \param[in] path the file name | /// \param[in] path the file name | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus OpenForAppend(const std::string &path); | |||||
| Status OpenForAppend(const std::string &path); | |||||
| /// \brief Write header to disk | /// \brief Write header to disk | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus Commit(); | |||||
| Status Commit(); | |||||
| /// \brief Set file size | /// \brief Set file size | ||||
| /// \param[in] header_size the size of header, only (1<<N) is accepted | /// \param[in] header_size the size of header, only (1<<N) is accepted | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus SetHeaderSize(const uint64_t &header_size); | |||||
| Status SetHeaderSize(const uint64_t &header_size); | |||||
| /// \brief Set page size | /// \brief Set page size | ||||
| /// \param[in] page_size the size of page, only (1<<N) is accepted | /// \param[in] page_size the size of page, only (1<<N) is accepted | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus SetPageSize(const uint64_t &page_size); | |||||
| Status SetPageSize(const uint64_t &page_size); | |||||
| /// \brief Set shard header | /// \brief Set shard header | ||||
| /// \param[in] header_data the info of header | /// \param[in] header_data the info of header | ||||
| /// WARNING, only called when file is empty | /// WARNING, only called when file is empty | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus SetShardHeader(std::shared_ptr<ShardHeader> header_data); | |||||
| Status SetShardHeader(std::shared_ptr<ShardHeader> header_data); | |||||
| /// \brief write raw data by group size | /// \brief write raw data by group size | ||||
| /// \param[in] raw_data the vector of raw json data, vector format | /// \param[in] raw_data the vector of raw json data, vector format | ||||
| /// \param[in] blob_data the vector of image data | /// \param[in] blob_data the vector of image data | ||||
| /// \param[in] sign validate data or not | /// \param[in] sign validate data or not | ||||
| /// \return MSRStatus the status of MSRStatus to judge if write successfully | /// \return MSRStatus the status of MSRStatus to judge if write successfully | ||||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||||
| bool sign = true, bool parallel_writer = false); | |||||
| /// \brief write raw data by group size for call from python | |||||
| /// \param[in] raw_data the vector of raw json data, python-handle format | |||||
| /// \param[in] blob_data the vector of image data | |||||
| /// \param[in] sign validate data or not | |||||
| /// \return MSRStatus the status of MSRStatus to judge if write successfully | |||||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||||
| bool sign = true, bool parallel_writer = false); | |||||
| Status WriteRawData(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||||
| bool sign = true, bool parallel_writer = false); | |||||
| /// \brief write raw data by group size for call from python | /// \brief write raw data by group size for call from python | ||||
| /// \param[in] raw_data the vector of raw json data, python-handle format | /// \param[in] raw_data the vector of raw json data, python-handle format | ||||
| /// \param[in] blob_data the vector of blob json data, python-handle format | /// \param[in] blob_data the vector of blob json data, python-handle format | ||||
| /// \param[in] sign validate data or not | /// \param[in] sign validate data or not | ||||
| /// \return MSRStatus the status of MSRStatus to judge if write successfully | /// \return MSRStatus the status of MSRStatus to judge if write successfully | ||||
| MSRStatus WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, | |||||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true, | |||||
| bool parallel_writer = false); | |||||
| Status WriteRawData(std::map<uint64_t, std::vector<py::handle>> &raw_data, | |||||
| std::map<uint64_t, std::vector<py::handle>> &blob_data, bool sign = true, | |||||
| bool parallel_writer = false); | |||||
| MSRStatus MergeBlobData(const std::vector<string> &blob_fields, | |||||
| const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data, | |||||
| std::shared_ptr<std::vector<uint8_t>> *output); | |||||
| Status MergeBlobData(const std::vector<string> &blob_fields, | |||||
| const std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> &row_bin_data, | |||||
| std::shared_ptr<std::vector<uint8_t>> *output); | |||||
| static MSRStatus Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, | |||||
| const std::vector<std::string> &file_names); | |||||
| static Status Initialize(const std::unique_ptr<ShardWriter> *writer_ptr, const std::vector<std::string> &file_names); | |||||
| private: | private: | ||||
| /// \brief write shard header data to disk | /// \brief write shard header data to disk | ||||
| MSRStatus WriteShardHeader(); | |||||
| Status WriteShardHeader(); | |||||
| /// \brief erase error data | /// \brief erase error data | ||||
| void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data); | void DeleteErrorData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data); | ||||
| @@ -130,108 +121,107 @@ class __attribute__((visibility("default"))) ShardWriter { | |||||
| std::map<int, std::string> &err_raw_data); | std::map<int, std::string> &err_raw_data); | ||||
| /// \brief validate raw data | /// \brief validate raw data | ||||
| std::tuple<MSRStatus, int, int> ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, | |||||
| std::vector<std::vector<uint8_t>> &blob_data, bool sign); | |||||
| Status ValidateRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &blob_data, | |||||
| bool sign, std::shared_ptr<std::pair<int, int>> *count_ptr); | |||||
| /// \brief fill data array in multiple thread run | /// \brief fill data array in multiple thread run | ||||
| void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data, | void FillArray(int start, int end, std::map<uint64_t, vector<json>> &raw_data, | ||||
| std::vector<std::vector<uint8_t>> &bin_data); | std::vector<std::vector<uint8_t>> &bin_data); | ||||
| /// \brief serialized raw data | /// \brief serialized raw data | ||||
| MSRStatus SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, | |||||
| std::vector<std::vector<uint8_t>> &bin_data, uint32_t row_count); | |||||
| Status SerializeRawData(std::map<uint64_t, std::vector<json>> &raw_data, std::vector<std::vector<uint8_t>> &bin_data, | |||||
| uint32_t row_count); | |||||
| /// \brief write all data parallel | /// \brief write all data parallel | ||||
| MSRStatus ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| Status ParallelWriteData(const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| /// \brief write data shard by shard | /// \brief write data shard by shard | ||||
| MSRStatus WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| Status WriteByShard(int shard_id, int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| /// \brief break image data up into multiple row groups | /// \brief break image data up into multiple row groups | ||||
| MSRStatus CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page, | |||||
| const std::shared_ptr<Page> &last_blob_page); | |||||
| Status CutRowGroup(int start_row, int end_row, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| std::vector<std::pair<int, int>> &rows_in_group, const std::shared_ptr<Page> &last_raw_page, | |||||
| const std::shared_ptr<Page> &last_blob_page); | |||||
| /// \brief append partial blob data to previous page | /// \brief append partial blob data to previous page | ||||
| MSRStatus AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| const std::shared_ptr<Page> &last_blob_page); | |||||
| /// \brief write new blob data page to disk | |||||
| MSRStatus NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| Status AppendBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::vector<std::pair<int, int>> &rows_in_group, | const std::vector<std::pair<int, int>> &rows_in_group, | ||||
| const std::shared_ptr<Page> &last_blob_page); | const std::shared_ptr<Page> &last_blob_page); | ||||
| /// \brief write new blob data page to disk | |||||
| Status NewBlobPage(const int &shard_id, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| const std::shared_ptr<Page> &last_blob_page); | |||||
| /// \brief shift last row group to next raw page for new appending | /// \brief shift last row group to next raw page for new appending | ||||
| MSRStatus ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| std::shared_ptr<Page> &last_raw_page); | |||||
| Status ShiftRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| std::shared_ptr<Page> &last_raw_page); | |||||
| /// \brief write raw data page to disk | /// \brief write raw data page to disk | ||||
| MSRStatus WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| Status WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| std::shared_ptr<Page> &last_raw_page, const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| /// \brief generate empty raw data page | /// \brief generate empty raw data page | ||||
| void EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||||
| Status EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||||
| /// \brief append a row group at the end of raw page | /// \brief append a row group at the end of raw page | ||||
| MSRStatus AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| const int &chunk_id, int &last_row_groupId, std::shared_ptr<Page> last_raw_page, | |||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| Status AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id, | |||||
| int &last_row_groupId, std::shared_ptr<Page> last_raw_page, | |||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| /// \brief write blob chunk to disk | /// \brief write blob chunk to disk | ||||
| MSRStatus FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::pair<int, int> &blob_row); | |||||
| Status FlushBlobChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::vector<uint8_t>> &blob_data, | |||||
| const std::pair<int, int> &blob_row); | |||||
| /// \brief write raw chunk to disk | /// \brief write raw chunk to disk | ||||
| MSRStatus FlushRawChunk(const std::shared_ptr<std::fstream> &out, | |||||
| const std::vector<std::pair<int, int>> &rows_in_group, const int &chunk_id, | |||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| Status FlushRawChunk(const std::shared_ptr<std::fstream> &out, const std::vector<std::pair<int, int>> &rows_in_group, | |||||
| const int &chunk_id, const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| /// \brief break up into tasks by shard | /// \brief break up into tasks by shard | ||||
| std::vector<std::pair<int, int>> BreakIntoShards(); | std::vector<std::pair<int, int>> BreakIntoShards(); | ||||
| /// \brief calculate raw data size row by row | /// \brief calculate raw data size row by row | ||||
| MSRStatus SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| Status SetRawDataSize(const std::vector<std::vector<uint8_t>> &bin_raw_data); | |||||
| /// \brief calculate blob data size row by row | /// \brief calculate blob data size row by row | ||||
| MSRStatus SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data); | |||||
| Status SetBlobDataSize(const std::vector<std::vector<uint8_t>> &blob_data); | |||||
| /// \brief populate last raw page pointer | /// \brief populate last raw page pointer | ||||
| void SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||||
| Status SetLastRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page); | |||||
| /// \brief populate last blob page pointer | /// \brief populate last blob page pointer | ||||
| void SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page); | |||||
| Status SetLastBlobPage(const int &shard_id, std::shared_ptr<Page> &last_blob_page); | |||||
| /// \brief check the data by schema | /// \brief check the data by schema | ||||
| MSRStatus CheckData(const std::map<uint64_t, std::vector<json>> &raw_data); | |||||
| Status CheckData(const std::map<uint64_t, std::vector<json>> &raw_data); | |||||
| /// \brief check the data and type | /// \brief check the data and type | ||||
| MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, | |||||
| std::map<int, std::string> &err_raw_data); | |||||
| Status CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, | |||||
| std::map<int, std::string> &err_raw_data); | |||||
| /// \brief Lock writer and save pages info | /// \brief Lock writer and save pages info | ||||
| int LockWriter(bool parallel_writer = false); | |||||
| Status LockWriter(bool parallel_writer, std::unique_ptr<int> *fd_ptr); | |||||
| /// \brief Unlock writer and save pages info | /// \brief Unlock writer and save pages info | ||||
| MSRStatus UnlockWriter(int fd, bool parallel_writer = false); | |||||
| Status UnlockWriter(int fd, bool parallel_writer = false); | |||||
| /// \brief Check raw data before writing | /// \brief Check raw data before writing | ||||
| MSRStatus WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||||
| bool sign, int *schema_count, int *row_count); | |||||
| Status WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>> &raw_data, vector<vector<uint8_t>> &blob_data, | |||||
| bool sign, int *schema_count, int *row_count); | |||||
| /// \brief Get full path from file name | /// \brief Get full path from file name | ||||
| MSRStatus GetFullPathFromFileName(const std::vector<std::string> &paths); | |||||
| Status GetFullPathFromFileName(const std::vector<std::string> &paths); | |||||
| /// \brief Open files | /// \brief Open files | ||||
| MSRStatus OpenDataFiles(bool append); | |||||
| Status OpenDataFiles(bool append); | |||||
| /// \brief Remove lock file | /// \brief Remove lock file | ||||
| MSRStatus RemoveLockFile(); | |||||
| Status RemoveLockFile(); | |||||
| /// \brief Init lock file | /// \brief Init lock file | ||||
| MSRStatus InitLockFile(); | |||||
| Status InitLockFile(); | |||||
| private: | private: | ||||
| const std::string kLockFileSuffix = "_Locker"; | const std::string kLockFileSuffix = "_Locker"; | ||||
| @@ -37,70 +37,48 @@ ShardIndexGenerator::ShardIndexGenerator(const std::string &file_path, bool appe | |||||
| task_(0), | task_(0), | ||||
| write_success_(true) {} | write_success_(true) {} | ||||
| MSRStatus ShardIndexGenerator::Build() { | |||||
| auto ret = ShardHeader::BuildSingleHeader(file_path_); | |||||
| if (ret.first != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| auto json_header = ret.second; | |||||
| auto ret2 = GetDatasetFiles(file_path_, json_header["shard_addresses"]); | |||||
| if (SUCCESS != ret2.first) { | |||||
| return FAILED; | |||||
| } | |||||
| Status ShardIndexGenerator::Build() { | |||||
| std::shared_ptr<json> header_ptr; | |||||
| RETURN_IF_NOT_OK(ShardHeader::BuildSingleHeader(file_path_, &header_ptr)); | |||||
| auto ds = std::make_shared<std::vector<std::string>>(); | |||||
| RETURN_IF_NOT_OK(GetDatasetFiles(file_path_, (*header_ptr)["shard_addresses"], &ds)); | |||||
| ShardHeader header = ShardHeader(); | ShardHeader header = ShardHeader(); | ||||
| auto addresses = ret2.second; | |||||
| if (header.BuildDataset(addresses) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(header.BuildDataset(*ds)); | |||||
| shard_header_ = header; | shard_header_ = header; | ||||
| MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; | MS_LOG(INFO) << "Init header from mindrecord file for index successfully."; | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const string &field, json input) { | |||||
| if (field.empty()) { | |||||
| MS_LOG(ERROR) << "The input field is None."; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| if (input.empty()) { | |||||
| MS_LOG(ERROR) << "The input json is None."; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| Status ShardIndexGenerator::GetValueByField(const string &field, json input, std::shared_ptr<std::string> *value) { | |||||
| RETURN_UNEXPECTED_IF_NULL(value); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!field.empty(), "The input field is empty."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!input.empty(), "The input json is empty."); | |||||
| // parameter input does not contain the field | // parameter input does not contain the field | ||||
| if (input.find(field) == input.end()) { | |||||
| MS_LOG(ERROR) << "The field " << field << " is not found in parameter " << input; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(input.find(field) != input.end(), | |||||
| "The field " + field + " is not found in json " + input.dump()); | |||||
| // schema does not contain the field | // schema does not contain the field | ||||
| auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; | auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; | ||||
| if (schema.find(field) == schema.end()) { | |||||
| MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), | |||||
| "The field " + field + " is not found in schema " + schema.dump()); | |||||
| // field should be scalar type | // field should be scalar type | ||||
| if (kScalarFieldTypeSet.find(schema[field]["type"]) == kScalarFieldTypeSet.end()) { | |||||
| MS_LOG(ERROR) << "The field " << field << " type is " << schema[field]["type"] << ", it is not retrievable"; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| kScalarFieldTypeSet.find(schema[field]["type"]) != kScalarFieldTypeSet.end(), | |||||
| "The field " + field + " type is " + schema[field]["type"].dump() + " which is not retrievable."); | |||||
| if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { | if (kNumberFieldTypeSet.find(schema[field]["type"]) != kNumberFieldTypeSet.end()) { | ||||
| auto schema_field_options = schema[field]; | auto schema_field_options = schema[field]; | ||||
| if (schema_field_options.find("shape") == schema_field_options.end()) { | |||||
| return {SUCCESS, input[field].dump()}; | |||||
| } else { | |||||
| // field with shape option | |||||
| MS_LOG(ERROR) << "The field " << field << " shape is " << schema[field]["shape"] << " which is not retrievable"; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| schema_field_options.find("shape") == schema_field_options.end(), | |||||
| "The field " + field + " shape is " + schema[field]["shape"].dump() + " which is not retrievable."); | |||||
| *value = std::make_shared<std::string>(input[field].dump()); | |||||
| } else { | |||||
| // the field type is string in here | |||||
| *value = std::make_shared<std::string>(input[field].get<std::string>()); | |||||
| } | } | ||||
| // the field type is string in here | |||||
| return {SUCCESS, input[field].get<std::string>()}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { | std::string ShardIndexGenerator::TakeFieldType(const string &field_path, json schema) { | ||||
| @@ -150,24 +128,28 @@ int ShardIndexGenerator::Callback(void *not_used, int argc, char **argv, char ** | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { | |||||
| Status ShardIndexGenerator::ExecuteSQL(const std::string &sql, sqlite3 *db, const std::string &success_msg) { | |||||
| char *z_err_msg = nullptr; | char *z_err_msg = nullptr; | ||||
| int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); | int rc = sqlite3_exec(db, common::SafeCStr(sql), Callback, nullptr, &z_err_msg); | ||||
| if (rc != SQLITE_OK) { | if (rc != SQLITE_OK) { | ||||
| MS_LOG(ERROR) << "Sql error: " << z_err_msg; | |||||
| std::ostringstream oss; | |||||
| oss << "Failed to exec sqlite3_exec, msg is: " << z_err_msg; | |||||
| MS_LOG(DEBUG) << oss.str(); | |||||
| sqlite3_free(z_err_msg); | sqlite3_free(z_err_msg); | ||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||||
| } else { | } else { | ||||
| if (!success_msg.empty()) { | if (!success_msg.empty()) { | ||||
| MS_LOG(DEBUG) << "Sqlite3_exec exec success, msg is: " << success_msg; | |||||
| MS_LOG(DEBUG) << "Suceess to exec sqlite3_exec, msg is: " << success_msg; | |||||
| } | } | ||||
| sqlite3_free(z_err_msg); | sqlite3_free(z_err_msg); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } | } | ||||
| std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName( | |||||
| const std::pair<uint64_t, std::string> &field) { | |||||
| Status ShardIndexGenerator::GenerateFieldName(const std::pair<uint64_t, std::string> &field, | |||||
| std::shared_ptr<std::string> *fn_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(fn_ptr); | |||||
| // Replaces dots and dashes with underscores for SQL use | // Replaces dots and dashes with underscores for SQL use | ||||
| std::string field_name = field.second; | std::string field_name = field.second; | ||||
| // white list to avoid sql injection | // white list to avoid sql injection | ||||
| @@ -176,95 +158,71 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateFieldName( | |||||
| auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { | auto pos = std::find_if_not(field_name.begin(), field_name.end(), [](char x) { | ||||
| return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); | return (x >= 'A' && x <= 'Z') || (x >= 'a' && x <= 'z') || x == '_' || (x >= '0' && x <= '9'); | ||||
| }); | }); | ||||
| if (pos != field_name.end()) { | |||||
| MS_LOG(ERROR) << "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " << field_name; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| return {SUCCESS, field_name + "_" + std::to_string(field.first)}; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| pos == field_name.end(), | |||||
| "Field name must be composed of '0-9' or 'a-z' or 'A-Z' or '_', field_name: " + field_name); | |||||
| *fn_ptr = std::make_shared<std::string>(field_name + "_" + std::to_string(field.first)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CheckDatabase(const std::string &shard_address) { | |||||
| Status ShardIndexGenerator::CheckDatabase(const std::string &shard_address, sqlite3 **db) { | |||||
| auto realpath = Common::GetRealPath(shard_address); | auto realpath = Common::GetRealPath(shard_address); | ||||
| if (!realpath.has_value()) { | |||||
| MS_LOG(ERROR) << "Get real path failed, path=" << shard_address; | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| sqlite3 *db = nullptr; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address); | |||||
| std::ifstream fin(realpath.value()); | std::ifstream fin(realpath.value()); | ||||
| if (!append_ && fin.good()) { | if (!append_ && fin.good()) { | ||||
| MS_LOG(ERROR) << "Invalid file, DB file already exist: " << shard_address; | |||||
| fin.close(); | fin.close(); | ||||
| return {FAILED, nullptr}; | |||||
| RETURN_STATUS_UNEXPECTED("Invalid file, DB file already exist: " + shard_address); | |||||
| } | } | ||||
| fin.close(); | fin.close(); | ||||
| int rc = sqlite3_open_v2(common::SafeCStr(shard_address), &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr); | |||||
| if (rc) { | |||||
| MS_LOG(ERROR) << "Invalid file, failed to open database: " << shard_address << ", error" << sqlite3_errmsg(db); | |||||
| return {FAILED, nullptr}; | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Opened database successfully"; | |||||
| return {SUCCESS, db}; | |||||
| if (sqlite3_open_v2(common::SafeCStr(shard_address), db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, nullptr)) { | |||||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open database: " + shard_address + ", error" + | |||||
| std::string(sqlite3_errmsg(*db))); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Opened database successfully"; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { | |||||
| Status ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::string &shard_name) { | |||||
| // create shard_name table | // create shard_name table | ||||
| std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; | std::string sql = "DROP TABLE IF EXISTS SHARD_NAME;"; | ||||
| if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "drop table successfully.")); | |||||
| sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; | sql = "CREATE TABLE SHARD_NAME(NAME TEXT NOT NULL);"; | ||||
| if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, db, "create table successfully.")); | |||||
| sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);"; | sql = "INSERT INTO SHARD_NAME (NAME) VALUES (:SHARD_NAME);"; | ||||
| sqlite3_stmt *stmt = nullptr; | sqlite3_stmt *stmt = nullptr; | ||||
| if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { | if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { | ||||
| if (stmt != nullptr) { | if (stmt != nullptr) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| } | } | ||||
| MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; | |||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql); | |||||
| } | } | ||||
| int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME"); | int index = sqlite3_bind_parameter_index(stmt, ":SHARD_NAME"); | ||||
| if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) { | if (sqlite3_bind_text(stmt, index, shard_name.data(), -1, SQLITE_STATIC) != SQLITE_OK) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << shard_name; | |||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||||
| ", field value: " + std::string(shard_name)); | |||||
| } | } | ||||
| if (sqlite3_step(stmt) != SQLITE_DONE) { | if (sqlite3_step(stmt) != SQLITE_DONE) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt."); | |||||
| } | } | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) { | |||||
| Status ShardIndexGenerator::CreateDatabase(int shard_no, sqlite3 **db) { | |||||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | ||||
| if (shard_address.empty()) { | |||||
| MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| string shard_name = GetFileName(shard_address).second; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "Shard address is empty, shard No: " + shard_no); | |||||
| std::shared_ptr<std::string> fn_ptr; | |||||
| RETURN_IF_NOT_OK(GetFileName(shard_address, &fn_ptr)); | |||||
| shard_address += ".db"; | shard_address += ".db"; | ||||
| auto ret1 = CheckDatabase(shard_address); | |||||
| if (ret1.first != SUCCESS) { | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| sqlite3 *db = ret1.second; | |||||
| RETURN_IF_NOT_OK(CheckDatabase(shard_address, db)); | |||||
| std::string sql = "DROP TABLE IF EXISTS INDEXES;"; | std::string sql = "DROP TABLE IF EXISTS INDEXES;"; | ||||
| if (ExecuteSQL(sql, db, "drop table successfully.") != SUCCESS) { | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "drop table successfully.")); | |||||
| sql = | sql = | ||||
| "CREATE TABLE INDEXES(" | "CREATE TABLE INDEXES(" | ||||
| " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" | " ROW_ID INT NOT NULL, PAGE_ID_RAW INT NOT NULL" | ||||
| @@ -273,95 +231,79 @@ std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no | |||||
| ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; | ", PAGE_OFFSET_BLOB INT NOT NULL, PAGE_OFFSET_BLOB_END INT NOT NULL"; | ||||
| int field_no = 0; | int field_no = 0; | ||||
| std::shared_ptr<std::string> field_ptr; | |||||
| for (const auto &field : fields_) { | for (const auto &field : fields_) { | ||||
| uint64_t schema_id = field.first; | uint64_t schema_id = field.first; | ||||
| auto result = shard_header_.GetSchemaByID(schema_id); | |||||
| if (result.second != SUCCESS) { | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| json json_schema = (result.first->GetSchema())["schema"]; | |||||
| std::shared_ptr<Schema> schema_ptr; | |||||
| RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(schema_id, &schema_ptr)); | |||||
| json json_schema = (schema_ptr->GetSchema())["schema"]; | |||||
| std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); | std::string type = ConvertJsonToSQL(TakeFieldType(field.second, json_schema)); | ||||
| auto ret = GenerateFieldName(field); | |||||
| if (ret.first != SUCCESS) { | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| sql += ",INC_" + std::to_string(field_no++) + " INT, " + ret.second + " " + type; | |||||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &field_ptr)); | |||||
| sql += ",INC_" + std::to_string(field_no++) + " INT, " + *field_ptr + " " + type; | |||||
| } | } | ||||
| sql += ", PRIMARY KEY(ROW_ID"; | sql += ", PRIMARY KEY(ROW_ID"; | ||||
| for (uint64_t i = 0; i < fields_.size(); ++i) { | for (uint64_t i = 0; i < fields_.size(); ++i) { | ||||
| sql += ",INC_" + std::to_string(i); | sql += ",INC_" + std::to_string(i); | ||||
| } | } | ||||
| sql += "));"; | sql += "));"; | ||||
| if (ExecuteSQL(sql, db, "create table successfully.") != SUCCESS) { | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| if (CreateShardNameTable(db, shard_name) != SUCCESS) { | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| return {SUCCESS, db}; | |||||
| RETURN_IF_NOT_OK(ExecuteSQL(sql, *db, "create table successfully.")); | |||||
| RETURN_IF_NOT_OK(CreateShardNameTable(*db, *fn_ptr)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<json>> ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, | |||||
| std::fstream &in) { | |||||
| std::vector<json> schema_details; | |||||
| Status ShardIndexGenerator::GetSchemaDetails(const std::vector<uint64_t> &schema_lens, std::fstream &in, | |||||
| std::shared_ptr<std::vector<json>> *detail_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(detail_ptr); | |||||
| if (schema_count_ <= kMaxSchemaCount) { | if (schema_count_ <= kMaxSchemaCount) { | ||||
| for (int sc = 0; sc < schema_count_; ++sc) { | for (int sc = 0; sc < schema_count_; ++sc) { | ||||
| std::vector<char> schema_detail(schema_lens[sc]); | std::vector<char> schema_detail(schema_lens[sc]); | ||||
| auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); | auto &io_read = in.read(&schema_detail[0], schema_lens[sc]); | ||||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | if (!io_read.good() || io_read.fail() || io_read.bad()) { | ||||
| MS_LOG(ERROR) << "File read failed"; | |||||
| in.close(); | in.close(); | ||||
| return {FAILED, {}}; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||||
| } | } | ||||
| schema_details.emplace_back(json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end()))); | |||||
| auto j = json::from_msgpack(std::string(schema_detail.begin(), schema_detail.end())); | |||||
| (*detail_ptr)->emplace_back(j); | |||||
| } | } | ||||
| } | } | ||||
| return {SUCCESS, schema_details}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::string> ShardIndexGenerator::GenerateRawSQL( | |||||
| const std::vector<std::pair<uint64_t, std::string>> &fields) { | |||||
| Status ShardIndexGenerator::GenerateRawSQL(const std::vector<std::pair<uint64_t, std::string>> &fields, | |||||
| std::shared_ptr<std::string> *sql_ptr) { | |||||
| std::string sql = | std::string sql = | ||||
| "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," | "INSERT INTO INDEXES (ROW_ID,ROW_GROUP_ID,PAGE_ID_RAW,PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END," | ||||
| "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; | "PAGE_ID_BLOB,PAGE_OFFSET_BLOB,PAGE_OFFSET_BLOB_END"; | ||||
| int field_no = 0; | int field_no = 0; | ||||
| for (const auto &field : fields) { | for (const auto &field : fields) { | ||||
| auto ret = GenerateFieldName(field); | |||||
| if (ret.first != SUCCESS) { | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| sql += ",INC_" + std::to_string(field_no++) + "," + ret.second; | |||||
| std::shared_ptr<std::string> fn_ptr; | |||||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); | |||||
| sql += ",INC_" + std::to_string(field_no++) + "," + *fn_ptr; | |||||
| } | } | ||||
| sql += | sql += | ||||
| ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," | ") VALUES( :ROW_ID,:ROW_GROUP_ID,:PAGE_ID_RAW,:PAGE_OFFSET_RAW,:PAGE_OFFSET_RAW_END,:PAGE_ID_BLOB," | ||||
| ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; | ":PAGE_OFFSET_BLOB,:PAGE_OFFSET_BLOB_END"; | ||||
| field_no = 0; | field_no = 0; | ||||
| for (const auto &field : fields) { | for (const auto &field : fields) { | ||||
| auto ret = GenerateFieldName(field); | |||||
| if (ret.first != SUCCESS) { | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| sql += ",:INC_" + std::to_string(field_no++) + ",:" + ret.second; | |||||
| std::shared_ptr<std::string> fn_ptr; | |||||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); | |||||
| sql += ",:INC_" + std::to_string(field_no++) + ",:" + *fn_ptr; | |||||
| } | } | ||||
| sql += " )"; | sql += " )"; | ||||
| return {SUCCESS, sql}; | |||||
| *sql_ptr = std::make_shared<std::string>(sql); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( | |||||
| sqlite3 *db, const std::string &sql, | |||||
| const std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> &data) { | |||||
| Status ShardIndexGenerator::BindParameterExecuteSQL(sqlite3 *db, const std::string &sql, const ROW_DATA &data) { | |||||
| sqlite3_stmt *stmt = nullptr; | sqlite3_stmt *stmt = nullptr; | ||||
| if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { | if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { | ||||
| if (stmt != nullptr) { | if (stmt != nullptr) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| } | } | ||||
| MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; | |||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: could not prepare statement, sql: " + sql); | |||||
| } | } | ||||
| for (auto &row : data) { | for (auto &row : data) { | ||||
| for (auto &field : row) { | for (auto &field : row) { | ||||
| @@ -373,45 +315,47 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( | |||||
| if (field_type == "INTEGER") { | if (field_type == "INTEGER") { | ||||
| if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { | if (sqlite3_bind_int64(stmt, index, std::stoll(field_value)) != SQLITE_OK) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index | |||||
| << ", field value: " << std::stoll(field_value); | |||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||||
| ", field value: " + std::string(field_value)); | |||||
| } | } | ||||
| } else if (field_type == "NUMERIC") { | } else if (field_type == "NUMERIC") { | ||||
| if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { | if (sqlite3_bind_double(stmt, index, std::stold(field_value)) != SQLITE_OK) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index | |||||
| << ", field value: " << std::stold(field_value); | |||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||||
| ", field value: " + std::string(field_value)); | |||||
| } | } | ||||
| } else if (field_type == "NULL") { | } else if (field_type == "NULL") { | ||||
| if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { | if (sqlite3_bind_null(stmt, index) != SQLITE_OK) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: NULL"; | |||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||||
| ", field value: NULL"); | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { | if (sqlite3_bind_text(stmt, index, common::SafeCStr(field_value), -1, SQLITE_STATIC) != SQLITE_OK) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| MS_LOG(ERROR) << "SQL error: could not bind parameter, index: " << index << ", field value: " << field_value; | |||||
| return FAILED; | |||||
| sqlite3_close(db); | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: could not bind parameter, index: " + std::to_string(index) + | |||||
| ", field value: " + std::string(field_value)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| if (sqlite3_step(stmt) != SQLITE_DONE) { | if (sqlite3_step(stmt) != SQLITE_DONE) { | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| MS_LOG(ERROR) << "SQL error: Could not step (execute) stmt."; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("SQL error: Could not step (execute) stmt."); | |||||
| } | } | ||||
| (void)sqlite3_reset(stmt); | (void)sqlite3_reset(stmt); | ||||
| } | } | ||||
| (void)sqlite3_finalize(stmt); | (void)sqlite3_finalize(stmt); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||||
| const std::shared_ptr<Page> cur_blob_page, | |||||
| uint64_t &cur_blob_page_offset, std::fstream &in) { | |||||
| Status ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||||
| const std::shared_ptr<Page> cur_blob_page, uint64_t &cur_blob_page_offset, | |||||
| std::fstream &in) { | |||||
| row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); | row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); | ||||
| // blob data start | // blob data start | ||||
| @@ -419,89 +363,71 @@ MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::strin | |||||
| auto &io_seekg_blob = | auto &io_seekg_blob = | ||||
| in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); | in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); | ||||
| if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { | if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { | ||||
| MS_LOG(ERROR) << "File seekg failed"; | |||||
| in.close(); | in.close(); | ||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to seekg file."); | |||||
| } | } | ||||
| uint64_t image_size = 0; | uint64_t image_size = 0; | ||||
| auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len); | auto &io_read = in.read(reinterpret_cast<char *>(&image_size), kInt64Len); | ||||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | if (!io_read.good() || io_read.fail() || io_read.bad()) { | ||||
| MS_LOG(ERROR) << "File read failed"; | MS_LOG(ERROR) << "File read failed"; | ||||
| in.close(); | in.close(); | ||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||||
| } | } | ||||
| cur_blob_page_offset += (kInt64Len + image_size); | cur_blob_page_offset += (kInt64Len + image_size); | ||||
| row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); | row_data.emplace_back(":PAGE_OFFSET_BLOB_END", "INTEGER", std::to_string(cur_blob_page_offset)); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| void ShardIndexGenerator::AddIndexFieldByRawData( | |||||
| Status ShardIndexGenerator::AddIndexFieldByRawData( | |||||
| const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) { | const std::vector<json> &schema_detail, std::vector<std::tuple<std::string, std::string, std::string>> &row_data) { | ||||
| auto result = GenerateIndexFields(schema_detail); | |||||
| if (result.first == SUCCESS) { | |||||
| int index = 0; | |||||
| for (const auto &field : result.second) { | |||||
| // assume simple field: string , number etc. | |||||
| row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); | |||||
| row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); | |||||
| } | |||||
| } | |||||
| auto index_fields_ptr = std::make_shared<INDEX_FIELDS>(); | |||||
| RETURN_IF_NOT_OK(GenerateIndexFields(schema_detail, &index_fields_ptr)); | |||||
| int index = 0; | |||||
| for (const auto &field : *index_fields_ptr) { | |||||
| // assume simple field: string , number etc. | |||||
| row_data.emplace_back(":INC_" + std::to_string(index++), "INTEGER", "0"); | |||||
| row_data.emplace_back(":" + std::get<0>(field), std::get<1>(field), std::get<2>(field)); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, | |||||
| int raw_page_id, std::fstream &in) { | |||||
| std::vector<std::vector<std::tuple<std::string, std::string, std::string>>> full_data; | |||||
| Status ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, int> &blob_id_to_page_id, int raw_page_id, | |||||
| std::fstream &in, std::shared_ptr<ROW_DATA> *row_data_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(row_data_ptr); | |||||
| // current raw data page | // current raw data page | ||||
| auto ret1 = shard_header_.GetPage(shard_no, raw_page_id); | |||||
| if (ret1.second != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get page failed"; | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| std::shared_ptr<Page> cur_raw_page = ret1.first; | |||||
| std::shared_ptr<Page> page_ptr; | |||||
| RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, raw_page_id, &page_ptr)); | |||||
| // related blob page | // related blob page | ||||
| vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds(); | |||||
| vector<pair<int, uint64_t>> row_group_list = page_ptr->GetRowGroupIds(); | |||||
| // pair: row_group id, offset in raw data page | // pair: row_group id, offset in raw data page | ||||
| for (pair<int, int> blob_ids : row_group_list) { | for (pair<int, int> blob_ids : row_group_list) { | ||||
| // get blob data page according to row_group id | // get blob data page according to row_group id | ||||
| auto iter = blob_id_to_page_id.find(blob_ids.first); | auto iter = blob_id_to_page_id.find(blob_ids.first); | ||||
| if (iter == blob_id_to_page_id.end()) { | |||||
| MS_LOG(ERROR) << "Convert blob id failed"; | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| auto ret2 = shard_header_.GetPage(shard_no, iter->second); | |||||
| if (ret2.second != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get page failed"; | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| std::shared_ptr<Page> cur_blob_page = ret2.first; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(iter != blob_id_to_page_id.end(), "Failed to get page id from blob id."); | |||||
| std::shared_ptr<Page> blob_page_ptr; | |||||
| RETURN_IF_NOT_OK(shard_header_.GetPage(shard_no, iter->second, &blob_page_ptr)); | |||||
| // offset in current raw data page | // offset in current raw data page | ||||
| auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second); | auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second); | ||||
| uint64_t cur_blob_page_offset = 0; | uint64_t cur_blob_page_offset = 0; | ||||
| for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { | |||||
| for (unsigned int i = blob_page_ptr->GetStartRowID(); i < blob_page_ptr->GetEndRowID(); ++i) { | |||||
| std::vector<std::tuple<std::string, std::string, std::string>> row_data; | std::vector<std::tuple<std::string, std::string, std::string>> row_data; | ||||
| row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); | row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); | ||||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); | |||||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); | |||||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(blob_page_ptr->GetPageTypeID())); | |||||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(page_ptr->GetPageID())); | |||||
| // raw data start | // raw data start | ||||
| row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); | row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); | ||||
| // calculate raw data end | // calculate raw data end | ||||
| auto &io_seekg = | auto &io_seekg = | ||||
| in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||||
| in.seekg(page_size_ * (page_ptr->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | ||||
| MS_LOG(ERROR) << "File seekg failed"; | |||||
| return {FAILED, {}}; | |||||
| in.close(); | |||||
| RETURN_STATUS_UNEXPECTED("Failed to seekg file."); | |||||
| } | } | ||||
| std::vector<uint64_t> schema_lens; | std::vector<uint64_t> schema_lens; | ||||
| if (schema_count_ <= kMaxSchemaCount) { | if (schema_count_ <= kMaxSchemaCount) { | ||||
| for (int sc = 0; sc < schema_count_; sc++) { | for (int sc = 0; sc < schema_count_; sc++) { | ||||
| @@ -509,8 +435,8 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||||
| auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len); | auto &io_read = in.read(reinterpret_cast<char *>(&schema_size), kInt64Len); | ||||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | if (!io_read.good() || io_read.fail() || io_read.bad()) { | ||||
| MS_LOG(ERROR) << "File read failed"; | |||||
| return {FAILED, {}}; | |||||
| in.close(); | |||||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||||
| } | } | ||||
| cur_raw_page_offset += (kInt64Len + schema_size); | cur_raw_page_offset += (kInt64Len + schema_size); | ||||
| @@ -520,122 +446,79 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||||
| row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); | row_data.emplace_back(":PAGE_OFFSET_RAW_END", "INTEGER", std::to_string(cur_raw_page_offset)); | ||||
| // Getting schema for getting data for fields | // Getting schema for getting data for fields | ||||
| auto st_schema_detail = GetSchemaDetails(schema_lens, in); | |||||
| if (st_schema_detail.first != SUCCESS) { | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| auto detail_ptr = std::make_shared<std::vector<json>>(); | |||||
| RETURN_IF_NOT_OK(GetSchemaDetails(schema_lens, in, &detail_ptr)); | |||||
| // start blob page info | // start blob page info | ||||
| if (AddBlobPageInfo(row_data, cur_blob_page, cur_blob_page_offset, in) != SUCCESS) { | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| RETURN_IF_NOT_OK(AddBlobPageInfo(row_data, blob_page_ptr, cur_blob_page_offset, in)); | |||||
| // start index field | // start index field | ||||
| AddIndexFieldByRawData(st_schema_detail.second, row_data); | |||||
| full_data.push_back(std::move(row_data)); | |||||
| AddIndexFieldByRawData(*detail_ptr, row_data); | |||||
| (*row_data_ptr)->push_back(std::move(row_data)); | |||||
| } | } | ||||
| } | } | ||||
| return {SUCCESS, full_data}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) { | |||||
| std::vector<std::tuple<std::string, std::string, std::string>> fields; | |||||
| Status ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail, | |||||
| std::shared_ptr<INDEX_FIELDS> *index_fields_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(index_fields_ptr); | |||||
| // index fields | // index fields | ||||
| std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields(); | std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields(); | ||||
| for (const auto &field : index_fields) { | for (const auto &field : index_fields) { | ||||
| if (field.first >= schema_detail.size()) { | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| auto field_value = GetValueByField(field.second, schema_detail[field.first]); | |||||
| if (field_value.first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get value from json by field name failed"; | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| auto result = shard_header_.GetSchemaByID(field.first); | |||||
| if (result.second != SUCCESS) { | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, result.first->GetSchema()["schema"])); | |||||
| auto ret = GenerateFieldName(field); | |||||
| if (ret.first != SUCCESS) { | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| fields.emplace_back(ret.second, field_type, field_value.second); | |||||
| } | |||||
| return {SUCCESS, std::move(fields)}; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(field.first < schema_detail.size(), "Index field id is out of range."); | |||||
| std::shared_ptr<std::string> field_val_ptr; | |||||
| RETURN_IF_NOT_OK(GetValueByField(field.second, schema_detail[field.first], &field_val_ptr)); | |||||
| std::shared_ptr<Schema> schema_ptr; | |||||
| RETURN_IF_NOT_OK(shard_header_.GetSchemaByID(field.first, &schema_ptr)); | |||||
| std::string field_type = ConvertJsonToSQL(TakeFieldType(field.second, schema_ptr->GetSchema()["schema"])); | |||||
| std::shared_ptr<std::string> fn_ptr; | |||||
| RETURN_IF_NOT_OK(GenerateFieldName(field, &fn_ptr)); | |||||
| (*index_fields_ptr)->emplace_back(*fn_ptr, field_type, *field_val_ptr); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, std::pair<MSRStatus, sqlite3 *> &db, | |||||
| const std::vector<int> &raw_page_ids, | |||||
| const std::map<int, int> &blob_id_to_page_id) { | |||||
| Status ShardIndexGenerator::ExecuteTransaction(const int &shard_no, sqlite3 *db, const std::vector<int> &raw_page_ids, | |||||
| const std::map<int, int> &blob_id_to_page_id) { | |||||
| // Add index data to database | // Add index data to database | ||||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | ||||
| if (shard_address.empty()) { | |||||
| MS_LOG(ERROR) << "Invalid data, shard address is null"; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!shard_address.empty(), "shard address is empty."); | |||||
| auto realpath = Common::GetRealPath(shard_address); | auto realpath = Common::GetRealPath(shard_address); | ||||
| if (!realpath.has_value()) { | |||||
| MS_LOG(ERROR) << "Get real path failed, path=" << shard_address; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + shard_address); | |||||
| std::fstream in; | std::fstream in; | ||||
| in.open(realpath.value(), std::ios::in | std::ios::binary); | in.open(realpath.value(), std::ios::in | std::ios::binary); | ||||
| if (!in.good()) { | if (!in.good()) { | ||||
| MS_LOG(ERROR) << "Invalid file, failed to open file: " << shard_address; | |||||
| in.close(); | in.close(); | ||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to open file: " + shard_address); | |||||
| } | } | ||||
| (void)sqlite3_exec(db.second, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); | |||||
| (void)sqlite3_exec(db, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr); | |||||
| for (int raw_page_id : raw_page_ids) { | for (int raw_page_id : raw_page_ids) { | ||||
| auto sql = GenerateRawSQL(fields_); | |||||
| if (sql.first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Generate raw SQL failed"; | |||||
| in.close(); | |||||
| sqlite3_close(db.second); | |||||
| return FAILED; | |||||
| } | |||||
| auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); | |||||
| if (data.first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Generate raw data failed"; | |||||
| in.close(); | |||||
| sqlite3_close(db.second); | |||||
| return FAILED; | |||||
| } | |||||
| if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { | |||||
| MS_LOG(ERROR) << "Execute SQL failed"; | |||||
| in.close(); | |||||
| sqlite3_close(db.second); | |||||
| return FAILED; | |||||
| } | |||||
| MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; | |||||
| } | |||||
| (void)sqlite3_exec(db.second, "END TRANSACTION;", nullptr, nullptr, nullptr); | |||||
| std::shared_ptr<std::string> sql_ptr; | |||||
| RELEASE_AND_RETURN_IF_NOT_OK(GenerateRawSQL(fields_, &sql_ptr), db, in); | |||||
| auto row_data_ptr = std::make_shared<ROW_DATA>(); | |||||
| RELEASE_AND_RETURN_IF_NOT_OK(GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in, &row_data_ptr), db, in); | |||||
| RELEASE_AND_RETURN_IF_NOT_OK(BindParameterExecuteSQL(db, *sql_ptr, *row_data_ptr), db, in); | |||||
| MS_LOG(INFO) << "Insert " << row_data_ptr->size() << " rows to index db."; | |||||
| } | |||||
| (void)sqlite3_exec(db, "END TRANSACTION;", nullptr, nullptr, nullptr); | |||||
| in.close(); | in.close(); | ||||
| // Close database | // Close database | ||||
| if (sqlite3_close(db.second) != SQLITE_OK) { | |||||
| MS_LOG(ERROR) << "Close database failed"; | |||||
| return FAILED; | |||||
| } | |||||
| db.second = nullptr; | |||||
| return SUCCESS; | |||||
| sqlite3_close(db); | |||||
| db = nullptr; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||||
| Status ShardIndexGenerator::WriteToDatabase() { | |||||
| fields_ = shard_header_.GetFields(); | fields_ = shard_header_.GetFields(); | ||||
| page_size_ = shard_header_.GetPageSize(); | page_size_ = shard_header_.GetPageSize(); | ||||
| header_size_ = shard_header_.GetHeaderSize(); | header_size_ = shard_header_.GetHeaderSize(); | ||||
| schema_count_ = shard_header_.GetSchemaCount(); | schema_count_ = shard_header_.GetSchemaCount(); | ||||
| if (shard_header_.GetShardCount() > kMaxShardCount) { | |||||
| MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_header_.GetShardCount() <= kMaxShardCount, | |||||
| "num shards: " + std::to_string(shard_header_.GetShardCount()) + | |||||
| " exceeds max count:" + std::to_string(kMaxSchemaCount)); | |||||
| task_ = 0; // set two atomic vars to initial value | task_ = 0; // set two atomic vars to initial value | ||||
| write_success_ = true; | write_success_ = true; | ||||
| @@ -653,40 +536,41 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||||
| for (size_t t = 0; t < threads.capacity(); t++) { | for (size_t t = 0; t < threads.capacity(); t++) { | ||||
| threads[t].join(); | threads[t].join(); | ||||
| } | } | ||||
| return write_success_ ? SUCCESS : FAILED; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(write_success_, "Failed to write data to db."); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| void ShardIndexGenerator::DatabaseWriter() { | void ShardIndexGenerator::DatabaseWriter() { | ||||
| int shard_no = task_++; | int shard_no = task_++; | ||||
| while (shard_no < shard_header_.GetShardCount()) { | while (shard_no < shard_header_.GetShardCount()) { | ||||
| auto db = CreateDatabase(shard_no); | |||||
| if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { | |||||
| sqlite3 *db = nullptr; | |||||
| if (CreateDatabase(shard_no, &db).IsError()) { | |||||
| MS_LOG(ERROR) << "Failed to create Generate database."; | |||||
| write_success_ = false; | write_success_ = false; | ||||
| return; | return; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; | MS_LOG(INFO) << "Init index db for shard: " << shard_no << " successfully."; | ||||
| // Pre-processing page information | // Pre-processing page information | ||||
| auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; | auto total_pages = shard_header_.GetLastPageId(shard_no) + 1; | ||||
| std::map<int, int> blob_id_to_page_id; | std::map<int, int> blob_id_to_page_id; | ||||
| std::vector<int> raw_page_ids; | std::vector<int> raw_page_ids; | ||||
| for (uint64_t i = 0; i < total_pages; ++i) { | for (uint64_t i = 0; i < total_pages; ++i) { | ||||
| auto ret = shard_header_.GetPage(shard_no, i); | |||||
| if (ret.second != SUCCESS) { | |||||
| std::shared_ptr<Page> page_ptr; | |||||
| if (shard_header_.GetPage(shard_no, i, &page_ptr).IsError()) { | |||||
| MS_LOG(ERROR) << "Failed to get page."; | |||||
| write_success_ = false; | write_success_ = false; | ||||
| return; | return; | ||||
| } | } | ||||
| std::shared_ptr<Page> cur_page = ret.first; | |||||
| if (cur_page->GetPageType() == "RAW_DATA") { | |||||
| if (page_ptr->GetPageType() == "RAW_DATA") { | |||||
| raw_page_ids.push_back(i); | raw_page_ids.push_back(i); | ||||
| } else if (cur_page->GetPageType() == "BLOB_DATA") { | |||||
| blob_id_to_page_id[cur_page->GetPageTypeID()] = i; | |||||
| } else if (page_ptr->GetPageType() == "BLOB_DATA") { | |||||
| blob_id_to_page_id[page_ptr->GetPageTypeID()] = i; | |||||
| } | } | ||||
| } | } | ||||
| if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id) != SUCCESS) { | |||||
| if (ExecuteTransaction(shard_no, db, raw_page_ids, blob_id_to_page_id).IsError()) { | |||||
| MS_LOG(ERROR) << "Failed to execute transaction."; | |||||
| write_success_ = false; | write_success_ = false; | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -694,21 +578,12 @@ void ShardIndexGenerator::DatabaseWriter() { | |||||
| shard_no = task_++; | shard_no = task_++; | ||||
| } | } | ||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) { | |||||
| if (file_names.empty()) { | |||||
| MS_LOG(ERROR) << "Mindrecord files is empty."; | |||||
| return FAILED; | |||||
| } | |||||
| Status ShardIndexGenerator::Finalize(const std::vector<std::string> file_names) { | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!file_names.empty(), "Mindrecord files is empty."); | |||||
| ShardIndexGenerator sg{file_names[0]}; | ShardIndexGenerator sg{file_names[0]}; | ||||
| if (SUCCESS != sg.Build()) { | |||||
| MS_LOG(ERROR) << "Failed to build index generator."; | |||||
| return FAILED; | |||||
| } | |||||
| if (SUCCESS != sg.WriteToDatabase()) { | |||||
| MS_LOG(ERROR) << "Failed to write to database."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| RETURN_IF_NOT_OK(sg.Build()); | |||||
| RETURN_IF_NOT_OK(sg.WriteToDatabase()); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -30,9 +30,13 @@ namespace mindspore { | |||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardSegment::ShardSegment() { SetAllInIndex(false); } | ShardSegment::ShardSegment() { SetAllInIndex(false); } | ||||
| std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | |||||
| Status ShardSegment::GetCategoryFields(std::shared_ptr<vector<std::string>> *fields_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(fields_ptr); | |||||
| // Skip if already populated | // Skip if already populated | ||||
| if (!candidate_category_fields_.empty()) return {SUCCESS, candidate_category_fields_}; | |||||
| if (!candidate_category_fields_.empty()) { | |||||
| *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_); | |||||
| return Status::OK(); | |||||
| } | |||||
| std::string sql = "PRAGMA table_info(INDEXES);"; | std::string sql = "PRAGMA table_info(INDEXES);"; | ||||
| std::vector<std::vector<std::string>> field_names; | std::vector<std::vector<std::string>> field_names; | ||||
| @@ -40,11 +44,12 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | |||||
| char *errmsg = nullptr; | char *errmsg = nullptr; | ||||
| int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); | int rc = sqlite3_exec(database_paths_[0], common::SafeCStr(sql), SelectCallback, &field_names, &errmsg); | ||||
| if (rc != SQLITE_OK) { | if (rc != SQLITE_OK) { | ||||
| MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; | |||||
| std::ostringstream oss; | |||||
| oss << "Error in select statement, sql: " << sql + ", error: " << errmsg; | |||||
| sqlite3_free(errmsg); | sqlite3_free(errmsg); | ||||
| sqlite3_close(database_paths_[0]); | sqlite3_close(database_paths_[0]); | ||||
| database_paths_[0] = nullptr; | database_paths_[0] = nullptr; | ||||
| return {FAILED, vector<std::string>{}}; | |||||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "Get " << static_cast<int>(field_names.size()) << " records from index."; | MS_LOG(INFO) << "Get " << static_cast<int>(field_names.size()) << " records from index."; | ||||
| } | } | ||||
| @@ -55,53 +60,46 @@ std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | |||||
| sqlite3_free(errmsg); | sqlite3_free(errmsg); | ||||
| sqlite3_close(database_paths_[0]); | sqlite3_close(database_paths_[0]); | ||||
| database_paths_[0] = nullptr; | database_paths_[0] = nullptr; | ||||
| return {FAILED, vector<std::string>{}}; | |||||
| RETURN_STATUS_UNEXPECTED("idx is out of range."); | |||||
| } | } | ||||
| candidate_category_fields_.push_back(field_names[idx][1]); | candidate_category_fields_.push_back(field_names[idx][1]); | ||||
| idx += 2; | idx += 2; | ||||
| } | } | ||||
| sqlite3_free(errmsg); | sqlite3_free(errmsg); | ||||
| return {SUCCESS, candidate_category_fields_}; | |||||
| *fields_ptr = std::make_shared<vector<std::string>>(candidate_category_fields_); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardSegment::SetCategoryField(std::string category_field) { | |||||
| if (GetCategoryFields().first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get candidate category field failed"; | |||||
| return FAILED; | |||||
| } | |||||
| Status ShardSegment::SetCategoryField(std::string category_field) { | |||||
| std::shared_ptr<vector<std::string>> fields_ptr; | |||||
| RETURN_IF_NOT_OK(GetCategoryFields(&fields_ptr)); | |||||
| category_field = category_field + "_0"; | category_field = category_field + "_0"; | ||||
| if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), | if (std::any_of(std::begin(candidate_category_fields_), std::end(candidate_category_fields_), | ||||
| [category_field](std::string x) { return x == category_field; })) { | [category_field](std::string x) { return x == category_field; })) { | ||||
| current_category_field_ = category_field; | current_category_field_ = category_field; | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MS_LOG(ERROR) << "Field " << category_field << " is not a candidate category field."; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Field " + category_field + " is not a candidate category field."); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::string> ShardSegment::ReadCategoryInfo() { | |||||
| Status ShardSegment::ReadCategoryInfo(std::shared_ptr<std::string> *category_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(category_ptr); | |||||
| MS_LOG(INFO) << "Read category begin"; | MS_LOG(INFO) << "Read category begin"; | ||||
| auto ret = WrapCategoryInfo(); | |||||
| if (ret.first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get category info failed"; | |||||
| return {FAILED, ""}; | |||||
| } | |||||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||||
| // Convert category info to json string | // Convert category info to json string | ||||
| auto category_json_string = ToJsonForCategory(ret.second); | |||||
| *category_ptr = std::make_shared<std::string>(ToJsonForCategory(*category_info_ptr)); | |||||
| MS_LOG(INFO) << "Read category end"; | MS_LOG(INFO) << "Read category end"; | ||||
| return {SUCCESS, category_json_string}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegment::WrapCategoryInfo() { | |||||
| Status ShardSegment::WrapCategoryInfo(std::shared_ptr<CATEGORY_INFO> *category_info_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(category_info_ptr); | |||||
| std::map<std::string, int> counter; | std::map<std::string, int> counter; | ||||
| if (!ValidateFieldName(current_category_field_)) { | |||||
| MS_LOG(ERROR) << "category field error from index, it is: " << current_category_field_; | |||||
| return {FAILED, std::vector<std::tuple<int, std::string, int>>()}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(ValidateFieldName(current_category_field_), | |||||
| "Category field error from index, it is: " + current_category_field_); | |||||
| std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + | std::string sql = "SELECT " + current_category_field_ + ", COUNT(" + current_category_field_ + | ||||
| ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; | ") AS `value_occurrence` FROM indexes GROUP BY " + current_category_field_ + ";"; | ||||
| @@ -109,13 +107,13 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen | |||||
| std::vector<std::vector<std::string>> field_count; | std::vector<std::vector<std::string>> field_count; | ||||
| char *errmsg = nullptr; | char *errmsg = nullptr; | ||||
| int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg); | |||||
| if (rc != SQLITE_OK) { | |||||
| MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; | |||||
| if (sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &field_count, &errmsg) != SQLITE_OK) { | |||||
| std::ostringstream oss; | |||||
| oss << "Error in select statement, sql: " << sql + ", error: " << errmsg; | |||||
| sqlite3_free(errmsg); | sqlite3_free(errmsg); | ||||
| sqlite3_close(db); | sqlite3_close(db); | ||||
| db = nullptr; | db = nullptr; | ||||
| return {FAILED, std::vector<std::tuple<int, std::string, int>>()}; | |||||
| RETURN_STATUS_UNEXPECTED(oss.str()); | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "Get " << static_cast<int>(field_count.size()) << " records from index."; | MS_LOG(INFO) << "Get " << static_cast<int>(field_count.size()) << " records from index."; | ||||
| } | } | ||||
| @@ -127,14 +125,14 @@ std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> ShardSegmen | |||||
| } | } | ||||
| int idx = 0; | int idx = 0; | ||||
| std::vector<std::tuple<int, std::string, int>> category_vec(counter.size()); | |||||
| (void)std::transform(counter.begin(), counter.end(), category_vec.begin(), [&idx](std::tuple<std::string, int> item) { | |||||
| return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); | |||||
| }); | |||||
| return {SUCCESS, std::move(category_vec)}; | |||||
| (*category_info_ptr)->resize(counter.size()); | |||||
| (void)std::transform( | |||||
| counter.begin(), counter.end(), (*category_info_ptr)->begin(), | |||||
| [&idx](std::tuple<std::string, int> item) { return std::make_tuple(idx++, std::get<0>(item), std::get<1>(item)); }); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, std::string, int>> &tri_vec) { | |||||
| std::string ShardSegment::ToJsonForCategory(const CATEGORY_INFO &tri_vec) { | |||||
| std::vector<json> category_json_vec; | std::vector<json> category_json_vec; | ||||
| for (auto q : tri_vec) { | for (auto q : tri_vec) { | ||||
| json j; | json j; | ||||
| @@ -152,27 +150,20 @@ std::string ShardSegment::ToJsonForCategory(const std::vector<std::tuple<int, st | |||||
| return category_info.dump(); | return category_info.dump(); | ||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageById(int64_t category_id, | |||||
| int64_t page_no, | |||||
| int64_t n_rows_of_page) { | |||||
| auto ret = WrapCategoryInfo(); | |||||
| if (ret.first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get category info"; | |||||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||||
| } | |||||
| if (category_id >= static_cast<int>(ret.second.size()) || category_id < 0) { | |||||
| MS_LOG(ERROR) << "Illegal category id, id: " << category_id; | |||||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||||
| } | |||||
| int total_rows_in_category = std::get<2>(ret.second[category_id]); | |||||
| Status ShardSegment::ReadAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *page_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(page_ptr); | |||||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int>(category_info_ptr->size()) && category_id >= 0, | |||||
| "Invalid category id, id: " + std::to_string(category_id)); | |||||
| int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]); | |||||
| // Quit if category not found or page number is out of range | // Quit if category not found or page number is out of range | ||||
| if (total_rows_in_category <= 0 || page_no < 0 || n_rows_of_page <= 0 || | |||||
| page_no * n_rows_of_page >= total_rows_in_category) { | |||||
| MS_LOG(ERROR) << "Illegal page no / page size, page no: " << page_no << ", page size: " << n_rows_of_page; | |||||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 && | |||||
| page_no * n_rows_of_page < total_rows_in_category, | |||||
| "Invalid page no / page size, page no: " + std::to_string(page_no) + | |||||
| ", page size: " + std::to_string(n_rows_of_page)); | |||||
| std::vector<std::vector<uint8_t>> page; | |||||
| auto row_group_summary = ReadRowGroupSummary(); | auto row_group_summary = ReadRowGroupSummary(); | ||||
| uint64_t i_start = page_no * n_rows_of_page; | uint64_t i_start = page_no * n_rows_of_page; | ||||
| @@ -183,12 +174,12 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage | |||||
| auto shard_id = std::get<0>(rg); | auto shard_id = std::get<0>(rg); | ||||
| auto group_id = std::get<1>(rg); | auto group_id = std::get<1>(rg); | ||||
| auto details = ReadRowGroupCriteria( | |||||
| group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); | |||||
| if (SUCCESS != std::get<0>(details)) { | |||||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||||
| } | |||||
| auto offsets = std::get<4>(details); | |||||
| std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr; | |||||
| RETURN_IF_NOT_OK(ReadRowGroupCriteria( | |||||
| group_id, shard_id, | |||||
| std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""}, | |||||
| &row_group_brief_ptr)); | |||||
| auto offsets = std::get<3>(*row_group_brief_ptr); | |||||
| uint64_t number_of_rows = offsets.size(); | uint64_t number_of_rows = offsets.size(); | ||||
| if (idx + number_of_rows < i_start) { | if (idx + number_of_rows < i_start) { | ||||
| idx += number_of_rows; | idx += number_of_rows; | ||||
| @@ -197,131 +188,116 @@ std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPage | |||||
| for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { | for (uint64_t i = 0; i < number_of_rows; ++i, ++idx) { | ||||
| if (idx >= i_start && idx < i_end) { | if (idx >= i_start && idx < i_end) { | ||||
| auto ret1 = PackImages(group_id, shard_id, offsets[i]); | |||||
| if (SUCCESS != ret1.first) { | |||||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||||
| } | |||||
| page.push_back(std::move(ret1.second)); | |||||
| auto images_ptr = std::make_shared<std::vector<uint8_t>>(); | |||||
| RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr)); | |||||
| (*page_ptr)->push_back(std::move(*images_ptr)); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return {SUCCESS, std::move(page)}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id, int shard_id, | |||||
| std::vector<uint64_t> offset) { | |||||
| const auto &ret = shard_header_->GetPageByGroupId(group_id, shard_id); | |||||
| if (SUCCESS != ret.first) { | |||||
| return {FAILED, std::vector<uint8_t>()}; | |||||
| } | |||||
| const std::shared_ptr<Page> &blob_page = ret.second; | |||||
| Status ShardSegment::PackImages(int group_id, int shard_id, std::vector<uint64_t> offset, | |||||
| std::shared_ptr<std::vector<uint8_t>> *images_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(images_ptr); | |||||
| std::shared_ptr<Page> page_ptr; | |||||
| RETURN_IF_NOT_OK(shard_header_->GetPageByGroupId(group_id, shard_id, &page_ptr)); | |||||
| // Pack image list | // Pack image list | ||||
| std::vector<uint8_t> images(offset[1] - offset[0]); | |||||
| auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; | |||||
| (*images_ptr)->resize(offset[1] - offset[0]); | |||||
| auto file_offset = header_size_ + page_size_ * page_ptr->GetPageID() + offset[0]; | |||||
| auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); | auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); | ||||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | ||||
| MS_LOG(ERROR) << "File seekg failed"; | |||||
| file_streams_random_[0][shard_id]->close(); | file_streams_random_[0][shard_id]->close(); | ||||
| return {FAILED, {}}; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to seekg file."); | |||||
| } | } | ||||
| auto &io_read = file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&images[0]), offset[1] - offset[0]); | |||||
| auto &io_read = | |||||
| file_streams_random_[0][shard_id]->read(reinterpret_cast<char *>(&((*(*images_ptr))[0])), offset[1] - offset[0]); | |||||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | if (!io_read.good() || io_read.fail() || io_read.bad()) { | ||||
| MS_LOG(ERROR) << "File read failed"; | |||||
| file_streams_random_[0][shard_id]->close(); | file_streams_random_[0][shard_id]->close(); | ||||
| return {FAILED, {}}; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to read file."); | |||||
| } | } | ||||
| return {SUCCESS, std::move(images)}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardSegment::ReadAtPageByName(std::string category_name, | |||||
| int64_t page_no, | |||||
| int64_t n_rows_of_page) { | |||||
| auto ret = WrapCategoryInfo(); | |||||
| if (ret.first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get category info"; | |||||
| return {FAILED, std::vector<std::vector<uint8_t>>{}}; | |||||
| } | |||||
| for (const auto &categories : ret.second) { | |||||
| Status ShardSegment::ReadAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<std::vector<std::vector<uint8_t>>> *pages_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(pages_ptr); | |||||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||||
| for (const auto &categories : *category_info_ptr) { | |||||
| if (std::get<1>(categories) == category_name) { | if (std::get<1>(categories) == category_name) { | ||||
| auto result = ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page); | |||||
| return result; | |||||
| RETURN_IF_NOT_OK(ReadAtPageById(std::get<0>(categories), page_no, n_rows_of_page, pages_ptr)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } | } | ||||
| return {FAILED, std::vector<std::vector<uint8_t>>()}; | |||||
| RETURN_STATUS_UNEXPECTED("Category name can not match."); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageById( | |||||
| int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { | |||||
| auto ret = WrapCategoryInfo(); | |||||
| if (ret.first != SUCCESS || category_id >= static_cast<int>(ret.second.size())) { | |||||
| MS_LOG(ERROR) << "Illegal category id, id: " << category_id; | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||||
| } | |||||
| int total_rows_in_category = std::get<2>(ret.second[category_id]); | |||||
| // Quit if category not found or page number is out of range | |||||
| if (total_rows_in_category <= 0 || page_no < 0 || page_no * n_rows_of_page >= total_rows_in_category) { | |||||
| MS_LOG(ERROR) << "Illegal page no: " << page_no << ", page size: " << n_rows_of_page; | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||||
| } | |||||
| Status ShardSegment::ReadAllAtPageById(int64_t category_id, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<PAGES> *pages_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(pages_ptr); | |||||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(category_id < static_cast<int64_t>(category_info_ptr->size()), | |||||
| "Invalid category id: " + std::to_string(category_id)); | |||||
| std::vector<std::tuple<std::vector<uint8_t>, json>> page; | |||||
| int total_rows_in_category = std::get<2>((*category_info_ptr)[category_id]); | |||||
| // Quit if category not found or page number is out of range | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(total_rows_in_category > 0 && page_no >= 0 && n_rows_of_page > 0 && | |||||
| page_no * n_rows_of_page < total_rows_in_category, | |||||
| "Invalid page no / page size / total size, page no: " + std::to_string(page_no) + | |||||
| ", page size of page: " + std::to_string(n_rows_of_page) + | |||||
| ", total size: " + std::to_string(total_rows_in_category)); | |||||
| auto row_group_summary = ReadRowGroupSummary(); | auto row_group_summary = ReadRowGroupSummary(); | ||||
| int i_start = page_no * n_rows_of_page; | int i_start = page_no * n_rows_of_page; | ||||
| int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page); | int i_end = std::min(static_cast<int64_t>(total_rows_in_category), (page_no + 1) * n_rows_of_page); | ||||
| int idx = 0; | int idx = 0; | ||||
| for (const auto &rg : row_group_summary) { | for (const auto &rg : row_group_summary) { | ||||
| if (idx >= i_end) break; | |||||
| if (idx >= i_end) { | |||||
| break; | |||||
| } | |||||
| auto shard_id = std::get<0>(rg); | auto shard_id = std::get<0>(rg); | ||||
| auto group_id = std::get<1>(rg); | auto group_id = std::get<1>(rg); | ||||
| auto details = ReadRowGroupCriteria( | |||||
| group_id, shard_id, std::make_pair(CleanUp(current_category_field_), std::get<1>(ret.second[category_id]))); | |||||
| if (SUCCESS != std::get<0>(details)) { | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||||
| } | |||||
| auto offsets = std::get<4>(details); | |||||
| auto labels = std::get<5>(details); | |||||
| std::shared_ptr<ROW_GROUP_BRIEF> row_group_brief_ptr; | |||||
| RETURN_IF_NOT_OK(ReadRowGroupCriteria( | |||||
| group_id, shard_id, | |||||
| std::make_pair(CleanUp(current_category_field_), std::get<1>((*category_info_ptr)[category_id])), {""}, | |||||
| &row_group_brief_ptr)); | |||||
| auto offsets = std::get<3>(*row_group_brief_ptr); | |||||
| auto labels = std::get<4>(*row_group_brief_ptr); | |||||
| int number_of_rows = offsets.size(); | int number_of_rows = offsets.size(); | ||||
| if (idx + number_of_rows < i_start) { | if (idx + number_of_rows < i_start) { | ||||
| idx += number_of_rows; | idx += number_of_rows; | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (number_of_rows > static_cast<int>(labels.size())) { | |||||
| MS_LOG(ERROR) << "Illegal row number of page: " << number_of_rows; | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(number_of_rows <= static_cast<int>(labels.size()), | |||||
| "Invalid row number of page: " + number_of_rows); | |||||
| for (int i = 0; i < number_of_rows; ++i, ++idx) { | for (int i = 0; i < number_of_rows; ++i, ++idx) { | ||||
| if (idx >= i_start && idx < i_end) { | if (idx >= i_start && idx < i_end) { | ||||
| auto ret1 = PackImages(group_id, shard_id, offsets[i]); | |||||
| if (SUCCESS != ret1.first) { | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||||
| } | |||||
| page.emplace_back(std::move(ret1.second), std::move(labels[i])); | |||||
| auto images_ptr = std::make_shared<std::vector<uint8_t>>(); | |||||
| RETURN_IF_NOT_OK(PackImages(group_id, shard_id, offsets[i], &images_ptr)); | |||||
| (*pages_ptr)->emplace_back(std::move(*images_ptr), std::move(labels[i])); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return {SUCCESS, std::move(page)}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardSegment::ReadAllAtPageByName( | |||||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page) { | |||||
| auto ret = WrapCategoryInfo(); | |||||
| if (ret.first != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Get category info"; | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||||
| } | |||||
| Status ShardSegment::ReadAllAtPageByName(std::string category_name, int64_t page_no, int64_t n_rows_of_page, | |||||
| std::shared_ptr<PAGES> *pages_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(pages_ptr); | |||||
| auto category_info_ptr = std::make_shared<CATEGORY_INFO>(); | |||||
| RETURN_IF_NOT_OK(WrapCategoryInfo(&category_info_ptr)); | |||||
| // category_name to category_id | // category_name to category_id | ||||
| int64_t category_id = -1; | int64_t category_id = -1; | ||||
| for (const auto &categories : ret.second) { | |||||
| for (const auto &categories : *category_info_ptr) { | |||||
| std::string categories_name = std::get<1>(categories); | std::string categories_name = std::get<1>(categories); | ||||
| if (categories_name == category_name) { | if (categories_name == category_name) { | ||||
| @@ -329,45 +305,8 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, json>>> ShardS | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| if (category_id == -1) { | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>{}}; | |||||
| } | |||||
| return ReadAllAtPageById(category_id, page_no, n_rows_of_page); | |||||
| } | |||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByIdPy( | |||||
| int64_t category_id, int64_t page_no, int64_t n_rows_of_page) { | |||||
| auto res = ReadAllAtPageById(category_id, page_no, n_rows_of_page); | |||||
| if (res.first != SUCCESS) { | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}}; | |||||
| } | |||||
| vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data; | |||||
| std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), | |||||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||||
| auto &j = std::get<1>(item); | |||||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||||
| }); | |||||
| return {SUCCESS, std::move(json_data)}; | |||||
| } | |||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ShardSegment::ReadAtPageByNamePy( | |||||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page) { | |||||
| auto res = ReadAllAtPageByName(category_name, page_no, n_rows_of_page); | |||||
| if (res.first != SUCCESS) { | |||||
| return {FAILED, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>{}}; | |||||
| } | |||||
| vector<std::tuple<std::vector<uint8_t>, pybind11::object>> json_data; | |||||
| std::transform(res.second.begin(), res.second.end(), std::back_inserter(json_data), | |||||
| [](const std::tuple<std::vector<uint8_t>, json> &item) { | |||||
| auto &j = std::get<1>(item); | |||||
| pybind11::object obj = nlohmann::detail::FromJsonImpl(j); | |||||
| return std::make_tuple(std::get<0>(item), std::move(obj)); | |||||
| }); | |||||
| return {SUCCESS, std::move(json_data)}; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(category_id != -1, "Invalid category name."); | |||||
| return ReadAllAtPageById(category_id, page_no, n_rows_of_page, pages_ptr); | |||||
| } | } | ||||
| std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { | std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { | ||||
| @@ -382,7 +321,9 @@ std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { | |||||
| } | } | ||||
| std::string ShardSegment::CleanUp(std::string field_name) { | std::string ShardSegment::CleanUp(std::string field_name) { | ||||
| while (field_name.back() >= '0' && field_name.back() <= '9') field_name.pop_back(); | |||||
| while (field_name.back() >= '0' && field_name.back() <= '9') { | |||||
| field_name.pop_back(); | |||||
| } | |||||
| field_name.pop_back(); | field_name.pop_back(); | ||||
| return field_name; | return field_name; | ||||
| } | } | ||||
| @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem | |||||
| num_categories_(num_categories), | num_categories_(num_categories), | ||||
| replacement_(replacement) {} | replacement_(replacement) {} | ||||
| MSRStatus ShardCategory::Execute(ShardTaskList &tasks) { return SUCCESS; } | |||||
| Status ShardCategory::Execute(ShardTaskList &tasks) { return Status::OK(); } | |||||
| int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | ||||
| if (dataset_size == 0) return dataset_size; | if (dataset_size == 0) return dataset_size; | ||||
| @@ -72,36 +72,36 @@ void ShardColumn::Init(const json &schema_json, bool compress_integer) { | |||||
| num_blob_column_ = blob_column_.size(); | num_blob_column_ = blob_column_.size(); | ||||
| } | } | ||||
| std::pair<MSRStatus, ColumnCategory> ShardColumn::GetColumnTypeByName(const std::string &column_name, | |||||
| ColumnDataType *column_data_type, | |||||
| uint64_t *column_data_type_size, | |||||
| std::vector<int64_t> *column_shape) { | |||||
| Status ShardColumn::GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type, | |||||
| uint64_t *column_data_type_size, std::vector<int64_t> *column_shape, | |||||
| ColumnCategory *column_category) { | |||||
| RETURN_UNEXPECTED_IF_NULL(column_data_type); | |||||
| RETURN_UNEXPECTED_IF_NULL(column_data_type_size); | |||||
| RETURN_UNEXPECTED_IF_NULL(column_shape); | |||||
| RETURN_UNEXPECTED_IF_NULL(column_category); | |||||
| // Skip if column not found | // Skip if column not found | ||||
| auto column_category = CheckColumnName(column_name); | |||||
| if (column_category == ColumnNotFound) { | |||||
| return {FAILED, ColumnNotFound}; | |||||
| } | |||||
| *column_category = CheckColumnName(column_name); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(*column_category != ColumnNotFound, "Invalid column category."); | |||||
| // Get data type and size | // Get data type and size | ||||
| auto column_id = column_name_id_[column_name]; | auto column_id = column_name_id_[column_name]; | ||||
| *column_data_type = column_data_type_[column_id]; | *column_data_type = column_data_type_[column_id]; | ||||
| *column_data_type_size = ColumnDataTypeSize[*column_data_type]; | *column_data_type_size = ColumnDataTypeSize[*column_data_type]; | ||||
| *column_shape = column_shape_[column_id]; | *column_shape = column_shape_[column_id]; | ||||
| return {SUCCESS, column_category}; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const json &columns_json, const unsigned char **data, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||||
| std::vector<int64_t> *column_shape) { | |||||
| Status ShardColumn::GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const json &columns_json, const unsigned char **data, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *const n_bytes, | |||||
| ColumnDataType *column_data_type, uint64_t *column_data_type_size, | |||||
| std::vector<int64_t> *column_shape) { | |||||
| RETURN_UNEXPECTED_IF_NULL(column_data_type); | |||||
| RETURN_UNEXPECTED_IF_NULL(column_data_type_size); | |||||
| RETURN_UNEXPECTED_IF_NULL(column_shape); | |||||
| // Skip if column not found | // Skip if column not found | ||||
| auto column_category = CheckColumnName(column_name); | auto column_category = CheckColumnName(column_name); | ||||
| if (column_category == ColumnNotFound) { | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(column_category != ColumnNotFound, "Invalid column category."); | |||||
| // Get data type and size | // Get data type and size | ||||
| auto column_id = column_name_id_[column_name]; | auto column_id = column_name_id_[column_name]; | ||||
| *column_data_type = column_data_type_[column_id]; | *column_data_type = column_data_type_[column_id]; | ||||
| @@ -110,37 +110,31 @@ MSRStatus ShardColumn::GetColumnValueByName(const std::string &column_name, cons | |||||
| // Retrieve value from json | // Retrieve value from json | ||||
| if (column_category == ColumnInRaw) { | if (column_category == ColumnInRaw) { | ||||
| if (GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes) == FAILED) { | |||||
| MS_LOG(ERROR) << "Error when get data from json, column name is " << column_name << "."; | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(GetColumnFromJson(column_name, columns_json, data_ptr, n_bytes)); | |||||
| *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| // Retrieve value from blob | // Retrieve value from blob | ||||
| if (GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes) == FAILED) { | |||||
| MS_LOG(ERROR) << "Error when get data from blob, column name is " << column_name << "."; | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(GetColumnFromBlob(column_name, columns_blob, data, data_ptr, n_bytes)); | |||||
| if (*data == nullptr) { | if (*data == nullptr) { | ||||
| *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | *data = reinterpret_cast<const unsigned char *>(data_ptr->get()); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) { | |||||
| Status ShardColumn::GetColumnFromJson(const std::string &column_name, const json &columns_json, | |||||
| std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes) { | |||||
| RETURN_UNEXPECTED_IF_NULL(n_bytes); | |||||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||||
| auto column_id = column_name_id_[column_name]; | auto column_id = column_name_id_[column_name]; | ||||
| auto column_data_type = column_data_type_[column_id]; | auto column_data_type = column_data_type_[column_id]; | ||||
| // Initialize num bytes | // Initialize num bytes | ||||
| *n_bytes = ColumnDataTypeSize[column_data_type]; | *n_bytes = ColumnDataTypeSize[column_data_type]; | ||||
| auto json_column_value = columns_json[column_name]; | auto json_column_value = columns_json[column_name]; | ||||
| if (!json_column_value.is_string() && !json_column_value.is_number()) { | |||||
| MS_LOG(ERROR) << "Conversion failed (" << json_column_value << ")."; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(json_column_value.is_string() || json_column_value.is_number(), | |||||
| "Conversion to string or number failed (" + json_column_value.dump() + ")."); | |||||
| switch (column_data_type) { | switch (column_data_type) { | ||||
| case ColumnFloat32: { | case ColumnFloat32: { | ||||
| return GetFloat<float>(data_ptr, json_column_value, false); | return GetFloat<float>(data_ptr, json_column_value, false); | ||||
| @@ -171,12 +165,13 @@ MSRStatus ShardColumn::GetColumnFromJson(const std::string &column_name, const j | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, | |||||
| bool use_double) { | |||||
| Status ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, | |||||
| bool use_double) { | |||||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | ||||
| if (json_column_value.is_number()) { | if (json_column_value.is_number()) { | ||||
| array_data[0] = json_column_value; | array_data[0] = json_column_value; | ||||
| @@ -189,8 +184,7 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons | |||||
| array_data[0] = json_column_value.get<float>(); | array_data[0] = json_column_value.get<float>(); | ||||
| } | } | ||||
| } catch (json::exception &e) { | } catch (json::exception &e) { | ||||
| MS_LOG(ERROR) << "Conversion to float failed (" << json_column_value << ")."; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Conversion to float failed (" + json_column_value.dump() + ")."); | |||||
| } | } | ||||
| } | } | ||||
| @@ -199,54 +193,43 @@ MSRStatus ShardColumn::GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, cons | |||||
| for (uint32_t i = 0; i < sizeof(T); i++) { | for (uint32_t i = 0; i < sizeof(T); i++) { | ||||
| (*data_ptr)[i] = *(data + i); | (*data_ptr)[i] = *(data + i); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) { | |||||
| Status ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value) { | |||||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||||
| std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | std::unique_ptr<T[]> array_data = std::make_unique<T[]>(1); | ||||
| int64_t temp_value; | int64_t temp_value; | ||||
| bool less_than_zero = false; | bool less_than_zero = false; | ||||
| if (json_column_value.is_number_integer()) { | if (json_column_value.is_number_integer()) { | ||||
| const json json_zero = 0; | const json json_zero = 0; | ||||
| if (json_column_value < json_zero) less_than_zero = true; | |||||
| if (json_column_value < json_zero) { | |||||
| less_than_zero = true; | |||||
| } | |||||
| temp_value = json_column_value; | temp_value = json_column_value; | ||||
| } else if (json_column_value.is_string()) { | } else if (json_column_value.is_string()) { | ||||
| std::string string_value = json_column_value; | std::string string_value = json_column_value; | ||||
| if (!string_value.empty() && string_value[0] == '-') { | |||||
| try { | |||||
| try { | |||||
| if (!string_value.empty() && string_value[0] == '-') { | |||||
| temp_value = std::stoll(string_value); | temp_value = std::stoll(string_value); | ||||
| less_than_zero = true; | less_than_zero = true; | ||||
| } catch (std::invalid_argument &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||||
| return FAILED; | |||||
| } catch (std::out_of_range &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| try { | |||||
| } else { | |||||
| temp_value = static_cast<int64_t>(std::stoull(string_value)); | temp_value = static_cast<int64_t>(std::stoull(string_value)); | ||||
| } catch (std::invalid_argument &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, invalid argument."; | |||||
| return FAILED; | |||||
| } catch (std::out_of_range &e) { | |||||
| MS_LOG(ERROR) << "Conversion to int failed, out of range."; | |||||
| return FAILED; | |||||
| } | } | ||||
| } catch (std::invalid_argument &e) { | |||||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what())); | |||||
| } catch (std::out_of_range &e) { | |||||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed: " + std::string(e.what())); | |||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Conversion to int failed."; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed."); | |||||
| } | } | ||||
| if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) || | if ((less_than_zero && temp_value < static_cast<int64_t>(std::numeric_limits<T>::min())) || | ||||
| (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) { | (!less_than_zero && static_cast<uint64_t>(temp_value) > static_cast<uint64_t>(std::numeric_limits<T>::max()))) { | ||||
| MS_LOG(ERROR) << "Conversion to int failed. Out of range"; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Conversion to int failed, out of range."); | |||||
| } | } | ||||
| array_data[0] = static_cast<T>(temp_value); | array_data[0] = static_cast<T>(temp_value); | ||||
| @@ -255,33 +238,26 @@ MSRStatus ShardColumn::GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const | |||||
| for (uint32_t i = 0; i < sizeof(T); i++) { | for (uint32_t i = 0; i < sizeof(T); i++) { | ||||
| (*data_ptr)[i] = *(data + i); | (*data_ptr)[i] = *(data + i); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||||
| uint64_t *const n_bytes) { | |||||
| Status ShardColumn::GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, | |||||
| const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, | |||||
| uint64_t *const n_bytes) { | |||||
| RETURN_UNEXPECTED_IF_NULL(data); | |||||
| uint64_t offset_address = 0; | uint64_t offset_address = 0; | ||||
| auto column_id = column_name_id_[column_name]; | auto column_id = column_name_id_[column_name]; | ||||
| if (GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(GetColumnAddressInBlock(column_id, columns_blob, n_bytes, &offset_address)); | |||||
| auto column_data_type = column_data_type_[column_id]; | auto column_data_type = column_data_type_[column_id]; | ||||
| if (has_compress_blob_ && column_data_type == ColumnInt32) { | if (has_compress_blob_ && column_data_type == ColumnInt32) { | ||||
| if (UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(UncompressInt<int32_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address)); | |||||
| } else if (has_compress_blob_ && column_data_type == ColumnInt64) { | } else if (has_compress_blob_ && column_data_type == ColumnInt64) { | ||||
| if (UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(UncompressInt<int64_t>(column_id, data_ptr, columns_blob, n_bytes, offset_address)); | |||||
| } else { | } else { | ||||
| *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address])); | *data = reinterpret_cast<const unsigned char *>(&(columns_blob[offset_address])); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | ||||
| @@ -296,7 +272,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { | |||||
| std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) { | std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) { | ||||
| // Skip if no compress columns | // Skip if no compress columns | ||||
| *compression_size = 0; | *compression_size = 0; | ||||
| if (!CheckCompressBlob()) return blob; | |||||
| if (!CheckCompressBlob()) { | |||||
| return blob; | |||||
| } | |||||
| std::vector<uint8_t> dst_blob; | std::vector<uint8_t> dst_blob; | ||||
| uint64_t i_src = 0; | uint64_t i_src = 0; | ||||
| @@ -380,12 +358,14 @@ vector<uint8_t> ShardColumn::CompressInt(const vector<uint8_t> &src_bytes, const | |||||
| return dst_bytes; | return dst_bytes; | ||||
| } | } | ||||
| MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||||
| uint64_t *num_bytes, uint64_t *shift_idx) { | |||||
| Status ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob, | |||||
| uint64_t *num_bytes, uint64_t *shift_idx) { | |||||
| RETURN_UNEXPECTED_IF_NULL(num_bytes); | |||||
| RETURN_UNEXPECTED_IF_NULL(shift_idx); | |||||
| if (num_blob_column_ == 1) { | if (num_blob_column_ == 1) { | ||||
| *num_bytes = columns_blob.size(); | *num_bytes = columns_blob.size(); | ||||
| *shift_idx = 0; | *shift_idx = 0; | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| auto blob_id = blob_column_id_[column_name_[column_id]]; | auto blob_id = blob_column_id_[column_name_[column_id]]; | ||||
| @@ -396,13 +376,14 @@ MSRStatus ShardColumn::GetColumnAddressInBlock(const uint64_t &column_id, const | |||||
| (*shift_idx) += kInt64Len; | (*shift_idx) += kInt64Len; | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, | |||||
| uint64_t shift_idx) { | |||||
| Status ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *const data_ptr, | |||||
| const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx) { | |||||
| RETURN_UNEXPECTED_IF_NULL(data_ptr); | |||||
| RETURN_UNEXPECTED_IF_NULL(num_bytes); | |||||
| auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); | auto num_elements = BytesBigToUInt64(columns_blob, shift_idx, kInt32Type); | ||||
| *num_bytes = sizeof(T) * num_elements; | *num_bytes = sizeof(T) * num_elements; | ||||
| @@ -421,19 +402,12 @@ MSRStatus ShardColumn::UncompressInt(const uint64_t &column_id, std::unique_ptr< | |||||
| auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | auto data = reinterpret_cast<const unsigned char *>(array_data.get()); | ||||
| *data_ptr = std::make_unique<unsigned char[]>(*num_bytes); | *data_ptr = std::make_unique<unsigned char[]>(*num_bytes); | ||||
| // field is none. for example: numpy is null | // field is none. for example: numpy is null | ||||
| if (*num_bytes == 0) { | if (*num_bytes == 0) { | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| int ret_code = memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes); | |||||
| if (ret_code != 0) { | |||||
| MS_LOG(ERROR) << "Failed to copy data!"; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(data_ptr->get(), *num_bytes, data, *num_bytes) == 0, "Failed to copy data!"); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | uint64_t ShardColumn::BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos, | ||||
| @@ -55,15 +55,11 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_ | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { | |||||
| Status ShardDistributedSample::PreExecute(ShardTaskList &tasks) { | |||||
| auto total_no = tasks.Size(); | auto total_no = tasks.Size(); | ||||
| if (no_of_padded_samples_ > 0 && first_epoch_) { | if (no_of_padded_samples_ > 0 && first_epoch_) { | ||||
| if (total_no % denominator_ != 0) { | |||||
| MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. " | |||||
| << "task size: " << total_no << ", number padded: " << no_of_padded_samples_ | |||||
| << ", denominator: " << denominator_; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(total_no % denominator_ == 0, | |||||
| "Dataset size plus number of padded samples is not divisible by number of shards."); | |||||
| } | } | ||||
| if (first_epoch_) { | if (first_epoch_) { | ||||
| first_epoch_ = false; | first_epoch_ = false; | ||||
| @@ -74,11 +70,9 @@ MSRStatus ShardDistributedSample::PreExecute(ShardTaskList &tasks) { | |||||
| if (shuffle_ == true) { | if (shuffle_ == true) { | ||||
| shuffle_op_->SetShardSampleCount(GetShardSampleCount()); | shuffle_op_->SetShardSampleCount(GetShardSampleCount()); | ||||
| shuffle_op_->UpdateShuffleMode(GetShuffleMode()); | shuffle_op_->UpdateShuffleMode(GetShuffleMode()); | ||||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); | |||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -38,104 +38,74 @@ ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), co | |||||
| index_ = std::make_shared<Index>(); | index_ = std::make_shared<Index>(); | ||||
| } | } | ||||
| MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { | |||||
| Status ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { | |||||
| shard_count_ = headers.size(); | shard_count_ = headers.size(); | ||||
| int shard_index = 0; | int shard_index = 0; | ||||
| bool first = true; | bool first = true; | ||||
| for (const auto &header : headers) { | for (const auto &header : headers) { | ||||
| if (first) { | if (first) { | ||||
| first = false; | first = false; | ||||
| if (ParseSchema(header["schema"]) != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| if (ParseIndexFields(header["index_fields"]) != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| if (ParseStatistics(header["statistics"]) != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ParseSchema(header["schema"])); | |||||
| RETURN_IF_NOT_OK(ParseIndexFields(header["index_fields"])); | |||||
| RETURN_IF_NOT_OK(ParseStatistics(header["statistics"])); | |||||
| ParseShardAddress(header["shard_addresses"]); | ParseShardAddress(header["shard_addresses"]); | ||||
| header_size_ = header["header_size"].get<uint64_t>(); | header_size_ = header["header_size"].get<uint64_t>(); | ||||
| page_size_ = header["page_size"].get<uint64_t>(); | page_size_ = header["page_size"].get<uint64_t>(); | ||||
| compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; | compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0; | ||||
| } | } | ||||
| if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ParsePage(header["page"], shard_index, load_dataset)); | |||||
| shard_index++; | shard_index++; | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::CheckFileStatus(const std::string &path) { | |||||
| Status ShardHeader::CheckFileStatus(const std::string &path) { | |||||
| auto realpath = Common::GetRealPath(path); | auto realpath = Common::GetRealPath(path); | ||||
| if (!realpath.has_value()) { | |||||
| MS_LOG(ERROR) << "Get real path failed, path=" << path; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path: " + path); | |||||
| std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary); | std::ifstream fin(realpath.value(), std::ios::in | std::ios::binary); | ||||
| if (!fin) { | |||||
| MS_LOG(ERROR) << "File does not exist or permission denied. path: " << path; | |||||
| return FAILED; | |||||
| } | |||||
| if (fin.fail()) { | |||||
| MS_LOG(ERROR) << "Failed to open file. path: " << path; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(fin, "Failed to open file. path: " + path); | |||||
| // fetch file size | // fetch file size | ||||
| auto &io_seekg = fin.seekg(0, std::ios::end); | auto &io_seekg = fin.seekg(0, std::ios::end); | ||||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | ||||
| fin.close(); | fin.close(); | ||||
| MS_LOG(ERROR) << "File seekg failed. path: " << path; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("File seekg failed. path: " + path); | |||||
| } | } | ||||
| size_t file_size = fin.tellg(); | size_t file_size = fin.tellg(); | ||||
| if (file_size < kMinFileSize) { | if (file_size < kMinFileSize) { | ||||
| fin.close(); | fin.close(); | ||||
| MS_LOG(ERROR) << "Invalid file. path: " << path; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Invalid file. path: " + path); | |||||
| } | } | ||||
| fin.close(); | fin.close(); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path) { | |||||
| if (CheckFileStatus(path) != SUCCESS) { | |||||
| return {FAILED, {}}; | |||||
| } | |||||
| Status ShardHeader::ValidateHeader(const std::string &path, std::shared_ptr<json> *header_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(header_ptr); | |||||
| RETURN_IF_NOT_OK(CheckFileStatus(path)); | |||||
| // read header size | // read header size | ||||
| json json_header; | json json_header; | ||||
| std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); | std::ifstream fin(common::SafeCStr(path), std::ios::in | std::ios::binary); | ||||
| if (!fin.is_open()) { | |||||
| MS_LOG(ERROR) << "File seekg failed. path: " << path; | |||||
| return {FAILED, json_header}; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(fin.is_open(), "Failed to open file. path: " + path); | |||||
| uint64_t header_size = 0; | uint64_t header_size = 0; | ||||
| auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len); | auto &io_read = fin.read(reinterpret_cast<char *>(&header_size), kInt64Len); | ||||
| if (!io_read.good() || io_read.fail() || io_read.bad()) { | if (!io_read.good() || io_read.fail() || io_read.bad()) { | ||||
| MS_LOG(ERROR) << "File read failed"; | |||||
| fin.close(); | fin.close(); | ||||
| return {FAILED, json_header}; | |||||
| RETURN_STATUS_UNEXPECTED("File read failed"); | |||||
| } | } | ||||
| if (header_size > kMaxHeaderSize) { | if (header_size > kMaxHeaderSize) { | ||||
| fin.close(); | fin.close(); | ||||
| MS_LOG(ERROR) << "Invalid file content. path: " << path; | |||||
| return {FAILED, json_header}; | |||||
| RETURN_STATUS_UNEXPECTED("Invalid file content. path: " + path); | |||||
| } | } | ||||
| // read header content | // read header content | ||||
| std::vector<uint8_t> header_content(header_size); | std::vector<uint8_t> header_content(header_size); | ||||
| auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size); | auto &io_read_content = fin.read(reinterpret_cast<char *>(&header_content[0]), header_size); | ||||
| if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { | if (!io_read_content.good() || io_read_content.fail() || io_read_content.bad()) { | ||||
| MS_LOG(ERROR) << "File read failed. path: " << path; | |||||
| fin.close(); | fin.close(); | ||||
| return {FAILED, json_header}; | |||||
| RETURN_STATUS_UNEXPECTED("File read failed. path: " + path); | |||||
| } | } | ||||
| fin.close(); | fin.close(); | ||||
| @@ -144,34 +114,35 @@ std::pair<MSRStatus, json> ShardHeader::ValidateHeader(const std::string &path) | |||||
| try { | try { | ||||
| json_header = json::parse(raw_header_content); | json_header = json::parse(raw_header_content); | ||||
| } catch (json::parse_error &e) { | } catch (json::parse_error &e) { | ||||
| MS_LOG(ERROR) << "Json parse error: " << e.what(); | |||||
| return {FAILED, json_header}; | |||||
| RETURN_STATUS_UNEXPECTED("Json parse error: " + std::string(e.what())); | |||||
| } | } | ||||
| return {SUCCESS, json_header}; | |||||
| *header_ptr = std::make_shared<json>(json_header); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &file_path) { | |||||
| auto ret = ValidateHeader(file_path); | |||||
| if (SUCCESS != ret.first) { | |||||
| return {FAILED, json()}; | |||||
| } | |||||
| json raw_header = ret.second; | |||||
| Status ShardHeader::BuildSingleHeader(const std::string &file_path, std::shared_ptr<json> *header_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(header_ptr); | |||||
| std::shared_ptr<json> raw_header; | |||||
| RETURN_IF_NOT_OK(ValidateHeader(file_path, &raw_header)); | |||||
| uint64_t compression_size = | uint64_t compression_size = | ||||
| raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0; | |||||
| json header = {{"shard_addresses", raw_header["shard_addresses"]}, | |||||
| {"header_size", raw_header["header_size"]}, | |||||
| {"page_size", raw_header["page_size"]}, | |||||
| raw_header->contains("compression_size") ? (*raw_header)["compression_size"].get<uint64_t>() : 0; | |||||
| json header = {{"shard_addresses", (*raw_header)["shard_addresses"]}, | |||||
| {"header_size", (*raw_header)["header_size"]}, | |||||
| {"page_size", (*raw_header)["page_size"]}, | |||||
| {"compression_size", compression_size}, | {"compression_size", compression_size}, | ||||
| {"index_fields", raw_header["index_fields"]}, | |||||
| {"blob_fields", raw_header["schema"][0]["blob_fields"]}, | |||||
| {"schema", raw_header["schema"][0]["schema"]}, | |||||
| {"version", raw_header["version"]}}; | |||||
| return {SUCCESS, header}; | |||||
| {"index_fields", (*raw_header)["index_fields"]}, | |||||
| {"blob_fields", (*raw_header)["schema"][0]["blob_fields"]}, | |||||
| {"schema", (*raw_header)["schema"][0]["schema"]}, | |||||
| {"version", (*raw_header)["version"]}}; | |||||
| *header_ptr = std::make_shared<json>(header); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) { | |||||
| Status ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, bool load_dataset) { | |||||
| uint32_t thread_num = std::thread::hardware_concurrency(); | uint32_t thread_num = std::thread::hardware_concurrency(); | ||||
| if (thread_num == 0) thread_num = kThreadNumber; | |||||
| if (thread_num == 0) { | |||||
| thread_num = kThreadNumber; | |||||
| } | |||||
| uint32_t work_thread_num = 0; | uint32_t work_thread_num = 0; | ||||
| uint32_t shard_count = file_paths.size(); | uint32_t shard_count = file_paths.size(); | ||||
| int group_num = ceil(shard_count * 1.0 / thread_num); | int group_num = ceil(shard_count * 1.0 / thread_num); | ||||
| @@ -194,12 +165,10 @@ MSRStatus ShardHeader::BuildDataset(const std::vector<std::string> &file_paths, | |||||
| } | } | ||||
| if (thread_status) { | if (thread_status) { | ||||
| thread_status = false; | thread_status = false; | ||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Error occurred in GetHeadersOneTask thread."); | |||||
| } | } | ||||
| if (SUCCESS != InitializeHeader(headers, load_dataset)) { | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| RETURN_IF_NOT_OK(InitializeHeader(headers, load_dataset)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers, | void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &headers, | ||||
| @@ -208,48 +177,39 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade | |||||
| return; | return; | ||||
| } | } | ||||
| for (int x = start; x < end; ++x) { | for (int x = start; x < end; ++x) { | ||||
| auto ret = ValidateHeader(realAddresses[x]); | |||||
| if (SUCCESS != ret.first) { | |||||
| std::shared_ptr<json> header; | |||||
| if (ValidateHeader(realAddresses[x], &header).IsError()) { | |||||
| thread_status = true; | thread_status = true; | ||||
| return; | return; | ||||
| } | } | ||||
| json header; | |||||
| header = ret.second; | |||||
| header["shard_addresses"] = realAddresses; | |||||
| if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) { | |||||
| MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump() | |||||
| (*header)["shard_addresses"] = realAddresses; | |||||
| if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), (*header)["version"]) == | |||||
| kSupportedVersion.end()) { | |||||
| MS_LOG(ERROR) << "Version wrong, file version is: " << (*header)["version"].dump() | |||||
| << ", lib version is: " << kVersion; | << ", lib version is: " << kVersion; | ||||
| thread_status = true; | thread_status = true; | ||||
| return; | return; | ||||
| } | } | ||||
| headers[x] = header; | |||||
| headers[x] = *header; | |||||
| } | } | ||||
| } | } | ||||
| MSRStatus ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) { | |||||
| Status ShardHeader::InitByFiles(const std::vector<std::string> &file_paths) { | |||||
| std::vector<std::string> file_names(file_paths.size()); | std::vector<std::string> file_names(file_paths.size()); | ||||
| std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { | std::transform(file_paths.begin(), file_paths.end(), file_names.begin(), [](std::string fp) -> std::string { | ||||
| if (GetFileName(fp).first == SUCCESS) { | |||||
| return GetFileName(fp).second; | |||||
| } | |||||
| std::shared_ptr<std::string> fn; | |||||
| return GetFileName(fp, &fn).IsOk() ? *fn : ""; | |||||
| }); | }); | ||||
| shard_addresses_ = std::move(file_names); | shard_addresses_ = std::move(file_names); | ||||
| shard_count_ = file_paths.size(); | shard_count_ = file_paths.size(); | ||||
| if (shard_count_ == 0) { | |||||
| return FAILED; | |||||
| } | |||||
| if (shard_count_ <= kMaxShardCount) { | |||||
| pages_.resize(shard_count_); | |||||
| } else { | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_count_ != 0 && (shard_count_ <= kMaxShardCount), | |||||
| "shard count is invalid. shard count: " + std::to_string(shard_count_)); | |||||
| pages_.resize(shard_count_); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| void ShardHeader::ParseHeader(const json &header) {} | |||||
| MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { | |||||
| Status ShardHeader::ParseIndexFields(const json &index_fields) { | |||||
| std::vector<std::pair<uint64_t, std::string>> parsed_index_fields; | std::vector<std::pair<uint64_t, std::string>> parsed_index_fields; | ||||
| for (auto &index_field : index_fields) { | for (auto &index_field : index_fields) { | ||||
| auto schema_id = index_field["schema_id"].get<uint64_t>(); | auto schema_id = index_field["schema_id"].get<uint64_t>(); | ||||
| @@ -257,18 +217,15 @@ MSRStatus ShardHeader::ParseIndexFields(const json &index_fields) { | |||||
| std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name); | std::pair<uint64_t, std::string> parsed_index_field(schema_id, field_name); | ||||
| parsed_index_fields.push_back(parsed_index_field); | parsed_index_fields.push_back(parsed_index_field); | ||||
| } | } | ||||
| if (!parsed_index_fields.empty() && AddIndexFields(parsed_index_fields) != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| RETURN_IF_NOT_OK(AddIndexFields(parsed_index_fields)); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { | |||||
| Status ShardHeader::ParsePage(const json &pages, int shard_index, bool load_dataset) { | |||||
| // set shard_index when load_dataset is false | // set shard_index when load_dataset is false | ||||
| if (shard_count_ > kMaxFileCount) { | |||||
| MS_LOG(ERROR) << "The number of mindrecord files is greater than max value: " << kMaxFileCount; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| shard_count_ <= kMaxFileCount, | |||||
| "The number of mindrecord files is greater than max value: " + std::to_string(kMaxFileCount)); | |||||
| if (pages_.empty() && shard_count_ <= kMaxFileCount) { | if (pages_.empty() && shard_count_ <= kMaxFileCount) { | ||||
| pages_.resize(shard_count_); | pages_.resize(shard_count_); | ||||
| } | } | ||||
| @@ -295,44 +252,37 @@ MSRStatus ShardHeader::ParsePage(const json &pages, int shard_index, bool load_d | |||||
| pages_[shard_index].push_back(std::move(parsed_page)); | pages_[shard_index].push_back(std::move(parsed_page)); | ||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::ParseStatistics(const json &statistics) { | |||||
| Status ShardHeader::ParseStatistics(const json &statistics) { | |||||
| for (auto &statistic : statistics) { | for (auto &statistic : statistics) { | ||||
| if (statistic.find("desc") == statistic.end() || statistic.find("statistics") == statistic.end()) { | |||||
| MS_LOG(ERROR) << "Deserialize statistics failed, statistic: " << statistics.dump(); | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||||
| statistic.find("desc") != statistic.end() && statistic.find("statistics") != statistic.end(), | |||||
| "Deserialize statistics failed, statistic: " + statistics.dump()); | |||||
| std::string statistic_description = statistic["desc"].get<std::string>(); | std::string statistic_description = statistic["desc"].get<std::string>(); | ||||
| json statistic_body = statistic["statistics"]; | json statistic_body = statistic["statistics"]; | ||||
| std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body); | std::shared_ptr<Statistics> parsed_statistic = Statistics::Build(statistic_description, statistic_body); | ||||
| if (!parsed_statistic) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_UNEXPECTED_IF_NULL(parsed_statistic); | |||||
| AddStatistic(parsed_statistic); | AddStatistic(parsed_statistic); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::ParseSchema(const json &schemas) { | |||||
| Status ShardHeader::ParseSchema(const json &schemas) { | |||||
| for (auto &schema : schemas) { | for (auto &schema : schemas) { | ||||
| // change how we get schemaBody once design is finalized | // change how we get schemaBody once design is finalized | ||||
| if (schema.find("desc") == schema.end() || schema.find("blob_fields") == schema.end() || | |||||
| schema.find("schema") == schema.end()) { | |||||
| MS_LOG(ERROR) << "Deserialize schema failed. schema: " << schema.dump(); | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find("desc") != schema.end() && schema.find("blob_fields") != schema.end() && | |||||
| schema.find("schema") != schema.end(), | |||||
| "Deserialize schema failed. schema: " + schema.dump()); | |||||
| std::string schema_description = schema["desc"].get<std::string>(); | std::string schema_description = schema["desc"].get<std::string>(); | ||||
| std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>(); | std::vector<std::string> blob_fields = schema["blob_fields"].get<std::vector<std::string>>(); | ||||
| json schema_body = schema["schema"]; | json schema_body = schema["schema"]; | ||||
| std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body); | std::shared_ptr<Schema> parsed_schema = Schema::Build(schema_description, schema_body); | ||||
| if (!parsed_schema) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_UNEXPECTED_IF_NULL(parsed_schema); | |||||
| AddSchema(parsed_schema); | AddSchema(parsed_schema); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| void ShardHeader::ParseShardAddress(const json &address) { | void ShardHeader::ParseShardAddress(const json &address) { | ||||
| @@ -340,7 +290,7 @@ void ShardHeader::ParseShardAddress(const json &address) { | |||||
| } | } | ||||
| std::vector<std::string> ShardHeader::SerializeHeader() { | std::vector<std::string> ShardHeader::SerializeHeader() { | ||||
| std::vector<string> header; | |||||
| std::vector<std::string> header; | |||||
| auto index = SerializeIndexFields(); | auto index = SerializeIndexFields(); | ||||
| auto stats = SerializeStatistics(); | auto stats = SerializeStatistics(); | ||||
| auto schema = SerializeSchema(); | auto schema = SerializeSchema(); | ||||
| @@ -406,45 +356,42 @@ std::string ShardHeader::SerializeSchema() { | |||||
| std::string ShardHeader::SerializeShardAddress() { | std::string ShardHeader::SerializeShardAddress() { | ||||
| json j; | json j; | ||||
| (void)std::transform(shard_addresses_.begin(), shard_addresses_.end(), std::back_inserter(j), | |||||
| [](const std::string &addr) { return GetFileName(addr).second; }); | |||||
| std::shared_ptr<std::string> fn_ptr; | |||||
| for (const auto &addr : shard_addresses_) { | |||||
| (void)GetFileName(addr, &fn_ptr); | |||||
| j.emplace_back(*fn_ptr); | |||||
| } | |||||
| return j.dump(); | return j.dump(); | ||||
| } | } | ||||
| std::pair<std::shared_ptr<Page>, MSRStatus> ShardHeader::GetPage(const int &shard_id, const int &page_id) { | |||||
| Status ShardHeader::GetPage(const int &shard_id, const int &page_id, std::shared_ptr<Page> *page_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(page_ptr); | |||||
| if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | ||||
| return std::make_pair(pages_[shard_id][page_id], SUCCESS); | |||||
| } else { | |||||
| return std::make_pair(nullptr, FAILED); | |||||
| *page_ptr = pages_[shard_id][page_id]; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| page_ptr = nullptr; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to Get Page."); | |||||
| } | } | ||||
| MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) { | |||||
| if (new_page == nullptr) { | |||||
| return FAILED; | |||||
| } | |||||
| Status ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) { | |||||
| int shard_id = new_page->GetShardID(); | int shard_id = new_page->GetShardID(); | ||||
| int page_id = new_page->GetPageID(); | int page_id = new_page->GetPageID(); | ||||
| if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | ||||
| pages_[shard_id][page_id] = new_page; | pages_[shard_id][page_id] = new_page; | ||||
| return SUCCESS; | |||||
| } else { | |||||
| return FAILED; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| RETURN_STATUS_UNEXPECTED("Failed to Set Page."); | |||||
| } | } | ||||
| MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) { | |||||
| if (new_page == nullptr) { | |||||
| return FAILED; | |||||
| } | |||||
| Status ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) { | |||||
| int shard_id = new_page->GetShardID(); | int shard_id = new_page->GetShardID(); | ||||
| int page_id = new_page->GetPageID(); | int page_id = new_page->GetPageID(); | ||||
| if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) { | if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) { | ||||
| pages_[shard_id].push_back(new_page); | pages_[shard_id].push_back(new_page); | ||||
| return SUCCESS; | |||||
| } else { | |||||
| return FAILED; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| RETURN_STATUS_UNEXPECTED("Failed to Add Page."); | |||||
| } | } | ||||
| int64_t ShardHeader::GetLastPageId(const int &shard_id) { | int64_t ShardHeader::GetLastPageId(const int &shard_id) { | ||||
| @@ -468,20 +415,18 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag | |||||
| return last_page_id; | return last_page_id; | ||||
| } | } | ||||
| const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId(const int &group_id, | |||||
| const int &shard_id) { | |||||
| if (shard_id >= static_cast<int>(pages_.size())) { | |||||
| MS_LOG(ERROR) << "Shard id is more than sum of shards."; | |||||
| return {FAILED, nullptr}; | |||||
| } | |||||
| Status ShardHeader::GetPageByGroupId(const int &group_id, const int &shard_id, std::shared_ptr<Page> *page_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(page_ptr); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(shard_id < static_cast<int>(pages_.size()), "Shard id is more than sum of shards."); | |||||
| for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | ||||
| auto page = pages_[shard_id][i - 1]; | auto page = pages_[shard_id][i - 1]; | ||||
| if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { | if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { | ||||
| return {SUCCESS, page}; | |||||
| *page_ptr = std::make_shared<Page>(*page); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(ERROR) << "Could not get page by group id " << group_id; | |||||
| return {FAILED, nullptr}; | |||||
| page_ptr = nullptr; | |||||
| RETURN_STATUS_UNEXPECTED("Failed to get page by group id: " + group_id); | |||||
| } | } | ||||
| int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { | int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { | ||||
| @@ -524,151 +469,88 @@ std::shared_ptr<Index> ShardHeader::InitIndexPtr() { | |||||
| return index; | return index; | ||||
| } | } | ||||
| MSRStatus ShardHeader::CheckIndexField(const std::string &field, const json &schema) { | |||||
| Status ShardHeader::CheckIndexField(const std::string &field, const json &schema) { | |||||
| // check field name is or is not valid | // check field name is or is not valid | ||||
| if (schema.find(field) == schema.end()) { | |||||
| MS_LOG(ERROR) << "Schema do not contain the field: " << field << "."; | |||||
| return FAILED; | |||||
| } | |||||
| if (schema[field]["type"] == "bytes") { | |||||
| MS_LOG(ERROR) << field << " is bytes type, can not be schema index field."; | |||||
| return FAILED; | |||||
| } | |||||
| if (schema.find(field) != schema.end() && schema[field].find("shape") != schema[field].end()) { | |||||
| MS_LOG(ERROR) << field << " array can not be schema index field."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) != schema.end(), "Filed can not found in schema."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(schema[field]["type"] != "Bytes", "bytes can not be as index field."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field) == schema.end() || schema[field].find("shape") == schema[field].end(), | |||||
| "array can not be as index field."); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||||
| Status ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||||
| if (fields.empty()) { | |||||
| return Status::OK(); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!GetSchemas().empty(), "Schema is empty."); | |||||
| // create index Object | // create index Object | ||||
| std::shared_ptr<Index> index = InitIndexPtr(); | std::shared_ptr<Index> index = InitIndexPtr(); | ||||
| if (fields.size() == kInt0) { | |||||
| MS_LOG(ERROR) << "There are no index fields"; | |||||
| return FAILED; | |||||
| } | |||||
| if (GetSchemas().empty()) { | |||||
| MS_LOG(ERROR) << "No schema is set"; | |||||
| return FAILED; | |||||
| } | |||||
| for (const auto &schemaPtr : schema_) { | for (const auto &schemaPtr : schema_) { | ||||
| auto result = GetSchemaByID(schemaPtr->GetSchemaID()); | |||||
| if (result.second != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Could not get schema by id."; | |||||
| return FAILED; | |||||
| } | |||||
| if (result.first == nullptr) { | |||||
| MS_LOG(ERROR) << "Could not get schema by id."; | |||||
| return FAILED; | |||||
| } | |||||
| json schema = result.first->GetSchema().at("schema"); | |||||
| std::shared_ptr<Schema> schema_ptr; | |||||
| RETURN_IF_NOT_OK(GetSchemaByID(schemaPtr->GetSchemaID(), &schema_ptr)); | |||||
| json schema = schema_ptr->GetSchema().at("schema"); | |||||
| // checkout and add fields for each schema | // checkout and add fields for each schema | ||||
| std::set<std::string> field_set; | std::set<std::string> field_set; | ||||
| for (const auto &item : index->GetFields()) { | for (const auto &item : index->GetFields()) { | ||||
| field_set.insert(item.second); | field_set.insert(item.second); | ||||
| } | } | ||||
| for (const auto &field : fields) { | for (const auto &field : fields) { | ||||
| if (field_set.find(field) != field_set.end()) { | |||||
| MS_LOG(ERROR) << "Add same index field twice"; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice."); | |||||
| // check field name is or is not valid | // check field name is or is not valid | ||||
| if (CheckIndexField(field, schema) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(CheckIndexField(field, schema)); | |||||
| field_set.insert(field); | field_set.insert(field); | ||||
| // add field into index | // add field into index | ||||
| index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); | index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); | ||||
| } | } | ||||
| } | } | ||||
| index_ = index; | index_ = index; | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) { | |||||
| Status ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) { | |||||
| // get all schema id | // get all schema id | ||||
| for (const auto &schema : schema_) { | for (const auto &schema : schema_) { | ||||
| auto bucket_it = bucket_count.find(schema->GetSchemaID()); | auto bucket_it = bucket_count.find(schema->GetSchemaID()); | ||||
| if (bucket_it != bucket_count.end()) { | |||||
| MS_LOG(ERROR) << "Schema duplication"; | |||||
| return FAILED; | |||||
| } else { | |||||
| bucket_count.insert(schema->GetSchemaID()); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(bucket_it == bucket_count.end(), "Schema duplication."); | |||||
| bucket_count.insert(schema->GetSchemaID()); | |||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) { | |||||
| Status ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::string>> fields) { | |||||
| if (fields.empty()) { | |||||
| return Status::OK(); | |||||
| } | |||||
| // create index Object | // create index Object | ||||
| std::shared_ptr<Index> index = InitIndexPtr(); | std::shared_ptr<Index> index = InitIndexPtr(); | ||||
| if (fields.size() == kInt0) { | |||||
| MS_LOG(ERROR) << "There are no index fields"; | |||||
| return FAILED; | |||||
| } | |||||
| // get all schema id | // get all schema id | ||||
| std::set<uint64_t> bucket_count; | std::set<uint64_t> bucket_count; | ||||
| if (GetAllSchemaID(bucket_count) != SUCCESS) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(GetAllSchemaID(bucket_count)); | |||||
| // check and add fields for each schema | // check and add fields for each schema | ||||
| std::set<std::pair<uint64_t, std::string>> field_set; | std::set<std::pair<uint64_t, std::string>> field_set; | ||||
| for (const auto &item : index->GetFields()) { | for (const auto &item : index->GetFields()) { | ||||
| field_set.insert(item); | field_set.insert(item); | ||||
| } | } | ||||
| for (const auto &field : fields) { | for (const auto &field : fields) { | ||||
| if (field_set.find(field) != field_set.end()) { | |||||
| MS_LOG(ERROR) << "Add same index field twice"; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(field_set.find(field) == field_set.end(), "Add same index field twice."); | |||||
| uint64_t schema_id = field.first; | uint64_t schema_id = field.first; | ||||
| std::string field_name = field.second; | std::string field_name = field.second; | ||||
| // check schemaId is or is not valid | // check schemaId is or is not valid | ||||
| if (bucket_count.find(schema_id) == bucket_count.end()) { | |||||
| MS_LOG(ERROR) << "Illegal schema id: " << schema_id; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(bucket_count.find(schema_id) != bucket_count.end(), "Invalid schema id: " + schema_id); | |||||
| // check field name is or is not valid | // check field name is or is not valid | ||||
| auto result = GetSchemaByID(schema_id); | |||||
| if (result.second != SUCCESS) { | |||||
| MS_LOG(ERROR) << "Could not get schema by id."; | |||||
| return FAILED; | |||||
| } | |||||
| json schema = result.first->GetSchema().at("schema"); | |||||
| if (schema.find(field_name) == schema.end()) { | |||||
| MS_LOG(ERROR) << "Schema " << schema_id << " do not contain the field: " << field_name; | |||||
| return FAILED; | |||||
| } | |||||
| if (CheckIndexField(field_name, schema) == FAILED) { | |||||
| return FAILED; | |||||
| } | |||||
| std::shared_ptr<Schema> schema_ptr; | |||||
| RETURN_IF_NOT_OK(GetSchemaByID(schema_id, &schema_ptr)); | |||||
| json schema = schema_ptr->GetSchema().at("schema"); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(schema.find(field_name) != schema.end(), | |||||
| "Schema " + std::to_string(schema_id) + " do not contain the field: " + field_name); | |||||
| RETURN_IF_NOT_OK(CheckIndexField(field_name, schema)); | |||||
| field_set.insert(field); | field_set.insert(field); | ||||
| // add field into index | // add field into index | ||||
| index.get()->AddIndexField(schema_id, field_name); | |||||
| index->AddIndexField(schema_id, field_name); | |||||
| } | } | ||||
| index_ = index; | index_ = index; | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { | std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { | ||||
| @@ -686,103 +568,71 @@ std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return | |||||
| std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; } | std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; } | ||||
| std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { | |||||
| Status ShardHeader::GetSchemaByID(int64_t schema_id, std::shared_ptr<Schema> *schema_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(schema_ptr); | |||||
| int64_t schemaSize = schema_.size(); | int64_t schemaSize = schema_.size(); | ||||
| if (schema_id < 0 || schema_id >= schemaSize) { | |||||
| MS_LOG(ERROR) << "Illegal schema id"; | |||||
| return std::make_pair(nullptr, FAILED); | |||||
| } | |||||
| return std::make_pair(schema_.at(schema_id), SUCCESS); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(schema_id >= 0 && schema_id < schemaSize, "schema id is invalid."); | |||||
| *schema_ptr = schema_.at(schema_id); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| std::pair<std::shared_ptr<Statistics>, MSRStatus> ShardHeader::GetStatisticByID(int64_t statistic_id) { | |||||
| Status ShardHeader::GetStatisticByID(int64_t statistic_id, std::shared_ptr<Statistics> *statistics_ptr) { | |||||
| RETURN_UNEXPECTED_IF_NULL(statistics_ptr); | |||||
| int64_t statistics_size = statistics_.size(); | int64_t statistics_size = statistics_.size(); | ||||
| if (statistic_id < 0 || statistic_id >= statistics_size) { | |||||
| return std::make_pair(nullptr, FAILED); | |||||
| } | |||||
| return std::make_pair(statistics_.at(statistic_id), SUCCESS); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(statistic_id >= 0 && statistic_id < statistics_size, "statistic id is invalid."); | |||||
| *statistics_ptr = statistics_.at(statistic_id); | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { | |||||
| Status ShardHeader::PagesToFile(const std::string dump_file_name) { | |||||
| auto realpath = Common::GetRealPath(dump_file_name); | auto realpath = Common::GetRealPath(dump_file_name); | ||||
| if (!realpath.has_value()) { | |||||
| MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name); | |||||
| // write header content to file, dump whatever is in the file before | // write header content to file, dump whatever is in the file before | ||||
| std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out); | std::ofstream page_out_handle(realpath.value(), std::ios_base::trunc | std::ios_base::out); | ||||
| if (page_out_handle.fail()) { | |||||
| MS_LOG(ERROR) << "Failed in opening page file"; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(page_out_handle.good(), "Failed to open page file."); | |||||
| auto pages = SerializePage(); | auto pages = SerializePage(); | ||||
| for (const auto &shard_pages : pages) { | for (const auto &shard_pages : pages) { | ||||
| page_out_handle << shard_pages << "\n"; | page_out_handle << shard_pages << "\n"; | ||||
| } | } | ||||
| page_out_handle.close(); | page_out_handle.close(); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { | |||||
| Status ShardHeader::FileToPages(const std::string dump_file_name) { | |||||
| for (auto &v : pages_) { // clean pages | for (auto &v : pages_) { // clean pages | ||||
| v.clear(); | v.clear(); | ||||
| } | } | ||||
| auto realpath = Common::GetRealPath(dump_file_name); | auto realpath = Common::GetRealPath(dump_file_name); | ||||
| if (!realpath.has_value()) { | |||||
| MS_LOG(ERROR) << "Get real path failed, path=" << dump_file_name; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(realpath.has_value(), "Get real path failed, path=" + dump_file_name); | |||||
| // attempt to open the file contains the page in json | // attempt to open the file contains the page in json | ||||
| std::ifstream page_in_handle(realpath.value()); | std::ifstream page_in_handle(realpath.value()); | ||||
| if (!page_in_handle.good()) { | |||||
| MS_LOG(INFO) << "No page file exists."; | |||||
| return SUCCESS; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(page_in_handle.good(), "No page file exists."); | |||||
| std::string line; | std::string line; | ||||
| while (std::getline(page_in_handle, line)) { | while (std::getline(page_in_handle, line)) { | ||||
| if (SUCCESS != ParsePage(json::parse(line), -1, true)) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ParsePage(json::parse(line), -1, true)); | |||||
| } | } | ||||
| page_in_handle.close(); | page_in_handle.close(); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||||
| uint64_t &schema_id) { | |||||
| if (header_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "ShardHeader pointer is NULL."; | |||||
| return FAILED; | |||||
| } | |||||
| Status ShardHeader::Initialize(const std::shared_ptr<ShardHeader> *header_ptr, const json &schema, | |||||
| const std::vector<std::string> &index_fields, std::vector<std::string> &blob_fields, | |||||
| uint64_t &schema_id) { | |||||
| RETURN_UNEXPECTED_IF_NULL(header_ptr); | |||||
| auto schema_ptr = Schema::Build("mindrecord", schema); | auto schema_ptr = Schema::Build("mindrecord", schema); | ||||
| if (schema_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "Got unexpected error when building mindrecord schema."; | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_UNEXPECTED_IF_NULL(schema_ptr); | |||||
| schema_id = (*header_ptr)->AddSchema(schema_ptr); | schema_id = (*header_ptr)->AddSchema(schema_ptr); | ||||
| // create index | // create index | ||||
| std::vector<std::pair<uint64_t, std::string>> id_index_fields; | std::vector<std::pair<uint64_t, std::string>> id_index_fields; | ||||
| if (!index_fields.empty()) { | if (!index_fields.empty()) { | ||||
| (void)std::transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields), | |||||
| [schema_id](const std::string &el) { return std::make_pair(schema_id, el); }); | |||||
| if (SUCCESS != (*header_ptr)->AddIndexFields(id_index_fields)) { | |||||
| MS_LOG(ERROR) << "Got unexpected error when adding mindrecord index."; | |||||
| return FAILED; | |||||
| } | |||||
| (void)transform(index_fields.begin(), index_fields.end(), std::back_inserter(id_index_fields), | |||||
| [schema_id](const std::string &el) { return std::make_pair(schema_id, el); }); | |||||
| RETURN_IF_NOT_OK((*header_ptr)->AddIndexFields(id_index_fields)); | |||||
| } | } | ||||
| auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; | auto build_schema_ptr = (*header_ptr)->GetSchemas()[0]; | ||||
| blob_fields = build_schema_ptr->GetBlobFields(); | blob_fields = build_schema_ptr->GetBlobFields(); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -37,13 +37,11 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem | |||||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | ||||
| } | } | ||||
| MSRStatus ShardPkSample::SufExecute(ShardTaskList &tasks) { | |||||
| Status ShardPkSample::SufExecute(ShardTaskList &tasks) { | |||||
| if (shuffle_ == true) { | if (shuffle_ == true) { | ||||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); | |||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -80,7 +80,7 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||||
| Status ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||||
| if (tasks.permutation_.empty()) { | if (tasks.permutation_.empty()) { | ||||
| ShardTaskList new_tasks; | ShardTaskList new_tasks; | ||||
| int total_no = static_cast<int>(tasks.sample_ids_.size()); | int total_no = static_cast<int>(tasks.sample_ids_.size()); | ||||
| @@ -110,9 +110,7 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| } else { | } else { | ||||
| ShardTaskList new_tasks; | ShardTaskList new_tasks; | ||||
| if (taking > static_cast<int>(tasks.sample_ids_.size())) { | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int>(tasks.sample_ids_.size()), "taking is out of range."); | |||||
| int total_no = static_cast<int>(tasks.permutation_.size()); | int total_no = static_cast<int>(tasks.permutation_.size()); | ||||
| int cnt = 0; | int cnt = 0; | ||||
| for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | ||||
| @@ -122,10 +120,10 @@ MSRStatus ShardSample::UpdateTasks(ShardTaskList &tasks, int taking) { | |||||
| } | } | ||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardSample::Execute(ShardTaskList &tasks) { | |||||
| Status ShardSample::Execute(ShardTaskList &tasks) { | |||||
| if (offset_ != -1) { | if (offset_ != -1) { | ||||
| int64_t old_v = 0; | int64_t old_v = 0; | ||||
| int num_rows_ = static_cast<int>(tasks.sample_ids_.size()); | int num_rows_ = static_cast<int>(tasks.sample_ids_.size()); | ||||
| @@ -146,10 +144,8 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) { | |||||
| no_of_samples_ = std::min(no_of_samples_, total_no); | no_of_samples_ = std::min(no_of_samples_, total_no); | ||||
| taking = no_of_samples_ - no_of_samples_ % no_of_categories; | taking = no_of_samples_ - no_of_samples_ % no_of_categories; | ||||
| } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) { | } else if (sampler_type_ == kSubsetRandomSampler || sampler_type_ == kSubsetSampler) { | ||||
| if (indices_.size() > static_cast<size_t>(total_no)) { | |||||
| MS_LOG(ERROR) << "parameter indices's size is greater than dataset size."; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(indices_.size() <= static_cast<size_t>(total_no), | |||||
| "Parameter indices's size is greater than dataset size."); | |||||
| } else { // constructor TopPercent | } else { // constructor TopPercent | ||||
| if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { | if (numerator_ > 0 && denominator_ > 0 && numerator_ <= denominator_) { | ||||
| if (numerator_ == 1 && denominator_ > 1) { // sharding | if (numerator_ == 1 && denominator_ > 1) { // sharding | ||||
| @@ -159,20 +155,17 @@ MSRStatus ShardSample::Execute(ShardTaskList &tasks) { | |||||
| taking -= (taking % no_of_categories); | taking -= (taking % no_of_categories); | ||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "parameter numerator or denominator is illegal"; | |||||
| return FAILED; | |||||
| RETURN_STATUS_UNEXPECTED("Parameter numerator or denominator is invalid."); | |||||
| } | } | ||||
| } | } | ||||
| return UpdateTasks(tasks, taking); | return UpdateTasks(tasks, taking); | ||||
| } | } | ||||
| MSRStatus ShardSample::SufExecute(ShardTaskList &tasks) { | |||||
| Status ShardSample::SufExecute(ShardTaskList &tasks) { | |||||
| if (sampler_type_ == kSubsetRandomSampler) { | if (sampler_type_ == kSubsetRandomSampler) { | ||||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| RETURN_IF_NOT_OK((*shuffle_op_)(tasks)); | |||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -38,12 +38,6 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, const json &schema) { | |||||
| return std::make_shared<Schema>(object_schema); | return std::make_shared<Schema>(object_schema); | ||||
| } | } | ||||
| std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema) { | |||||
| // validate check | |||||
| json schema_json = nlohmann::detail::ToJsonImpl(schema); | |||||
| return Build(std::move(desc), schema_json); | |||||
| } | |||||
| std::string Schema::GetDesc() const { return desc_; } | std::string Schema::GetDesc() const { return desc_; } | ||||
| json Schema::GetSchema() const { | json Schema::GetSchema() const { | ||||
| @@ -54,12 +48,6 @@ json Schema::GetSchema() const { | |||||
| return str_schema; | return str_schema; | ||||
| } | } | ||||
| pybind11::object Schema::GetSchemaForPython() const { | |||||
| json schema_json = GetSchema(); | |||||
| pybind11::object schema_py = nlohmann::detail::FromJsonImpl(schema_json); | |||||
| return schema_py; | |||||
| } | |||||
| void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } | void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } | ||||
| int64_t Schema::GetSchemaID() const { return schema_id_; } | int64_t Schema::GetSchemaID() const { return schema_id_; } | ||||
| @@ -38,7 +38,7 @@ int64_t ShardSequentialSample::GetNumSamples(int64_t dataset_size, int64_t num_c | |||||
| return std::min(static_cast<int64_t>(no_of_samples_), dataset_size); | return std::min(static_cast<int64_t>(no_of_samples_), dataset_size); | ||||
| } | } | ||||
| MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||||
| Status ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||||
| int64_t taking; | int64_t taking; | ||||
| int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size()); | int64_t total_no = static_cast<int64_t>(tasks.sample_ids_.size()); | ||||
| if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | if (no_of_samples_ == 0 && (per_ >= -kEpsilon && per_ <= kEpsilon)) { | ||||
| @@ -58,16 +58,15 @@ MSRStatus ShardSequentialSample::Execute(ShardTaskList &tasks) { | |||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| } else { // shuffled | } else { // shuffled | ||||
| ShardTaskList new_tasks; | ShardTaskList new_tasks; | ||||
| if (taking > static_cast<int64_t>(tasks.permutation_.size())) { | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(taking <= static_cast<int64_t>(tasks.permutation_.size()), | |||||
| "Taking is out of task range."); | |||||
| total_no = static_cast<int64_t>(tasks.permutation_.size()); | total_no = static_cast<int64_t>(tasks.permutation_.size()); | ||||
| for (size_t i = offset_; i < taking + offset_; ++i) { | for (size_t i = offset_; i < taking + offset_; ++i) { | ||||
| new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); | new_tasks.AssignTask(tasks, tasks.permutation_[i % total_no]); | ||||
| } | } | ||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| @@ -42,7 +42,7 @@ int64_t ShardShuffle::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||||
| return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); | return no_of_samples_ == 0 ? dataset_size : std::min(dataset_size, no_of_samples_); | ||||
| } | } | ||||
| MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||||
| Status ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||||
| uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories; | uint32_t individual_size = tasks.sample_ids_.size() / tasks.categories; | ||||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | ||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | for (uint32_t i = 0; i < tasks.categories; i++) { | ||||
| @@ -62,17 +62,14 @@ MSRStatus ShardShuffle::CategoryShuffle(ShardTaskList &tasks) { | |||||
| } | } | ||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { | |||||
| Status ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { | |||||
| if (no_of_samples_ == 0) { | if (no_of_samples_ == 0) { | ||||
| no_of_samples_ = static_cast<int>(tasks.Size()); | no_of_samples_ = static_cast<int>(tasks.Size()); | ||||
| } | } | ||||
| if (no_of_samples_ <= 0) { | |||||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); | |||||
| auto shard_sample_cout = GetShardSampleCount(); | auto shard_sample_cout = GetShardSampleCount(); | ||||
| // shuffle the files index | // shuffle the files index | ||||
| @@ -118,16 +115,14 @@ MSRStatus ShardShuffle::ShuffleFiles(ShardTaskList &tasks) { | |||||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | new_tasks.AssignTask(tasks, tasks.permutation_[i]); | ||||
| } | } | ||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { | |||||
| Status ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { | |||||
| if (no_of_samples_ == 0) { | if (no_of_samples_ == 0) { | ||||
| no_of_samples_ = static_cast<int>(tasks.Size()); | no_of_samples_ = static_cast<int>(tasks.Size()); | ||||
| } | } | ||||
| if (no_of_samples_ <= 0) { | |||||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); | |||||
| // reconstruct the permutation in file | // reconstruct the permutation in file | ||||
| // -- before -- | // -- before -- | ||||
| // file1: [0, 1, 2] | // file1: [0, 1, 2] | ||||
| @@ -154,13 +149,12 @@ MSRStatus ShardShuffle::ShuffleInfile(ShardTaskList &tasks) { | |||||
| new_tasks.AssignTask(tasks, tasks.permutation_[i]); | new_tasks.AssignTask(tasks, tasks.permutation_[i]); | ||||
| } | } | ||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| return Status::OK(); | |||||
| } | } | ||||
| MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||||
| Status ShardShuffle::Execute(ShardTaskList &tasks) { | |||||
| if (reshuffle_each_epoch_) shuffle_seed_++; | if (reshuffle_each_epoch_) shuffle_seed_++; | ||||
| if (tasks.categories < 1) { | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(tasks.categories >= 1, "Task category is invalid."); | |||||
| if (shuffle_type_ == kShuffleSample) { // shuffle each sample | if (shuffle_type_ == kShuffleSample) { // shuffle each sample | ||||
| if (tasks.permutation_.empty() == true) { | if (tasks.permutation_.empty() == true) { | ||||
| tasks.MakePerm(); | tasks.MakePerm(); | ||||
| @@ -169,10 +163,7 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||||
| if (replacement_ == true) { | if (replacement_ == true) { | ||||
| ShardTaskList new_tasks; | ShardTaskList new_tasks; | ||||
| if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size()); | if (no_of_samples_ == 0) no_of_samples_ = static_cast<int>(tasks.sample_ids_.size()); | ||||
| if (no_of_samples_ <= 0) { | |||||
| MS_LOG(ERROR) << "no_of_samples need to be positive."; | |||||
| return FAILED; | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(no_of_samples_ > 0, "Parameter no_of_samples need to be positive."); | |||||
| for (uint32_t i = 0; i < no_of_samples_; ++i) { | for (uint32_t i = 0; i < no_of_samples_; ++i) { | ||||
| new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); | new_tasks.AssignTask(tasks, tasks.GetRandomTaskID()); | ||||
| } | } | ||||
| @@ -190,20 +181,14 @@ MSRStatus ShardShuffle::Execute(ShardTaskList &tasks) { | |||||
| ShardTaskList::TaskListSwap(tasks, new_tasks); | ShardTaskList::TaskListSwap(tasks, new_tasks); | ||||
| } | } | ||||
| } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) { | } else if (GetShuffleMode() == dataset::ShuffleMode::kInfile) { | ||||
| auto ret = ShuffleInfile(tasks); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ShuffleInfile(tasks)); | |||||
| } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) { | } else if (GetShuffleMode() == dataset::ShuffleMode::kFiles) { | ||||
| auto ret = ShuffleFiles(tasks); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| RETURN_IF_NOT_OK(ShuffleFiles(tasks)); | |||||
| } | } | ||||
| } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | ||||
| return this->CategoryShuffle(tasks); | return this->CategoryShuffle(tasks); | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return Status::OK(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -35,19 +35,6 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, const json &stat | |||||
| return std::make_shared<Statistics>(object_statistics); | return std::make_shared<Statistics>(object_statistics); | ||||
| } | } | ||||
| std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle statistics) { | |||||
| // validate check | |||||
| json statistics_json = nlohmann::detail::ToJsonImpl(statistics); | |||||
| if (!Validate(statistics_json)) { | |||||
| return nullptr; | |||||
| } | |||||
| Statistics object_statistics; | |||||
| object_statistics.desc_ = std::move(desc); | |||||
| object_statistics.statistics_ = statistics_json; | |||||
| object_statistics.statistics_id_ = -1; | |||||
| return std::make_shared<Statistics>(object_statistics); | |||||
| } | |||||
| std::string Statistics::GetDesc() const { return desc_; } | std::string Statistics::GetDesc() const { return desc_; } | ||||
| json Statistics::GetStatistics() const { | json Statistics::GetStatistics() const { | ||||
| @@ -57,11 +44,6 @@ json Statistics::GetStatistics() const { | |||||
| return str_statistics; | return str_statistics; | ||||
| } | } | ||||
| pybind11::object Statistics::GetStatisticsForPython() const { | |||||
| json str_statistics = Statistics::GetStatistics(); | |||||
| return nlohmann::detail::FromJsonImpl(str_statistics); | |||||
| } | |||||
| void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } | void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } | ||||
| int64_t Statistics::GetStatisticsID() const { return statistics_id_; } | int64_t Statistics::GetStatisticsID() const { return statistics_id_; } | ||||
| @@ -85,15 +85,9 @@ uint32_t ShardTaskList::SizeOfRows() const { | |||||
| return nRows; | return nRows; | ||||
| } | } | ||||
| ShardTask &ShardTaskList::GetTaskByID(size_t id) { | |||||
| MS_ASSERT(id < task_list_.size()); | |||||
| return task_list_[id]; | |||||
| } | |||||
| ShardTask &ShardTaskList::GetTaskByID(size_t id) { return task_list_[id]; } | |||||
| int ShardTaskList::GetTaskSampleByID(size_t id) { | |||||
| MS_ASSERT(id < sample_ids_.size()); | |||||
| return sample_ids_[id]; | |||||
| } | |||||
| int ShardTaskList::GetTaskSampleByID(size_t id) { return sample_ids_[id]; } | |||||
| int ShardTaskList::GetRandomTaskID() { | int ShardTaskList::GetRandomTaskID() { | ||||
| std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | std::mt19937 gen = mindspore::dataset::GetRandomDevice(); | ||||
| @@ -70,7 +70,7 @@ class ShardReader: | |||||
| Raises: | Raises: | ||||
| MRMLaunchError: If failed to launch worker threads. | MRMLaunchError: If failed to launch worker threads. | ||||
| """ | """ | ||||
| ret = self._reader.launch(False) | |||||
| ret = self._reader.launch() | |||||
| if ret != ms.MSRStatus.SUCCESS: | if ret != ms.MSRStatus.SUCCESS: | ||||
| logger.error("Failed to launch worker threads.") | logger.error("Failed to launch worker threads.") | ||||
| raise MRMLaunchError | raise MRMLaunchError | ||||
| @@ -19,7 +19,6 @@ import mindspore._c_mindrecord as ms | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from .shardutils import populate_data, SUCCESS | from .shardutils import populate_data, SUCCESS | ||||
| from .shardheader import ShardHeader | from .shardheader import ShardHeader | ||||
| from .common.exceptions import MRMOpenError, MRMFetchCandidateFieldsError, MRMReadCategoryInfoError, MRMFetchDataError | |||||
| __all__ = ['ShardSegment'] | __all__ = ['ShardSegment'] | ||||
| @@ -73,15 +72,8 @@ class ShardSegment: | |||||
| Returns: | Returns: | ||||
| list[str], by which data could be grouped. | list[str], by which data could be grouped. | ||||
| Raises: | |||||
| MRMFetchCandidateFieldsError: If failed to get candidate category fields. | |||||
| """ | """ | ||||
| ret, fields = self._segment.get_category_fields() | |||||
| if ret != SUCCESS: | |||||
| logger.error("Failed to get candidate category fields.") | |||||
| raise MRMFetchCandidateFieldsError | |||||
| return fields | |||||
| return self._segment.get_category_fields() | |||||
| def set_category_field(self, category_field): | def set_category_field(self, category_field): | ||||
| """Select one category field to use.""" | """Select one category field to use.""" | ||||
| @@ -94,14 +86,8 @@ class ShardSegment: | |||||
| Returns: | Returns: | ||||
| str, description fo group information. | str, description fo group information. | ||||
| Raises: | |||||
| MRMReadCategoryInfoError: If failed to read category information. | |||||
| """ | """ | ||||
| ret, category_info = self._segment.read_category_info() | |||||
| if ret != SUCCESS: | |||||
| logger.error("Failed to read category information.") | |||||
| raise MRMReadCategoryInfoError | |||||
| return category_info | |||||
| return self._segment.read_category_info() | |||||
| def read_at_page_by_id(self, category_id, page, num_row): | def read_at_page_by_id(self, category_id, page, num_row): | ||||
| """ | """ | ||||
| @@ -116,13 +102,9 @@ class ShardSegment: | |||||
| list[dict] | list[dict] | ||||
| Raises: | Raises: | ||||
| MRMFetchDataError: If failed to read by category id. | |||||
| MRMUnsupportedSchemaError: If schema is invalid. | MRMUnsupportedSchemaError: If schema is invalid. | ||||
| """ | """ | ||||
| ret, data = self._segment.read_at_page_by_id(category_id, page, num_row) | |||||
| if ret != SUCCESS: | |||||
| logger.error("Failed to read by category id.") | |||||
| raise MRMFetchDataError | |||||
| data = self._segment.read_at_page_by_id(category_id, page, num_row) | |||||
| return [populate_data(raw, blob, self._columns, self._header.blob_fields, | return [populate_data(raw, blob, self._columns, self._header.blob_fields, | ||||
| self._header.schema) for blob, raw in data] | self._header.schema) for blob, raw in data] | ||||
| @@ -139,12 +121,8 @@ class ShardSegment: | |||||
| list[dict] | list[dict] | ||||
| Raises: | Raises: | ||||
| MRMFetchDataError: If failed to read by category name. | |||||
| MRMUnsupportedSchemaError: If schema is invalid. | MRMUnsupportedSchemaError: If schema is invalid. | ||||
| """ | """ | ||||
| ret, data = self._segment.read_at_page_by_name(category_name, page, num_row) | |||||
| if ret != SUCCESS: | |||||
| logger.error("Failed to read by category name.") | |||||
| raise MRMFetchDataError | |||||
| data = self._segment.read_at_page_by_name(category_name, page, num_row) | |||||
| return [populate_data(raw, blob, self._columns, self._header.blob_fields, | return [populate_data(raw, blob, self._columns, self._header.blob_fields, | ||||
| self._header.schema) for blob, raw in data] | self._header.schema) for blob, raw in data] | ||||
| @@ -384,8 +384,8 @@ void ShardWriterImageNetOpenForAppend(string filename) { | |||||
| { | { | ||||
| MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | MS_LOG(INFO) << "=============== images " << bin_data.size() << " ============================"; | ||||
| mindrecord::ShardWriter fw; | mindrecord::ShardWriter fw; | ||||
| auto ret = fw.OpenForAppend(filename); | |||||
| if (ret == FAILED) { | |||||
| auto status = fw.OpenForAppend(filename); | |||||
| if (status.IsError()) { | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -121,14 +121,19 @@ TEST_F(TestShard, TestShardHeaderPart) { | |||||
| re_statistics.push_back(*statistic); | re_statistics.push_back(*statistic); | ||||
| } | } | ||||
| ASSERT_EQ(re_statistics, validate_statistics); | ASSERT_EQ(re_statistics, validate_statistics); | ||||
| ASSERT_EQ(header_data.GetStatisticByID(-1).second, FAILED); | |||||
| ASSERT_EQ(header_data.GetStatisticByID(10).second, FAILED); | |||||
| std::shared_ptr<Statistics> statistics_ptr; | |||||
| auto status = header_data.GetStatisticByID(-1, &statistics_ptr); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| status = header_data.GetStatisticByID(10, &statistics_ptr); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| // test add index fields | // test add index fields | ||||
| std::vector<std::pair<uint64_t, std::string>> fields; | std::vector<std::pair<uint64_t, std::string>> fields; | ||||
| std::pair<uint64_t, std::string> pair1(0, "name"); | std::pair<uint64_t, std::string> pair1(0, "name"); | ||||
| fields.push_back(pair1); | fields.push_back(pair1); | ||||
| ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS); | |||||
| status = header_data.AddIndexFields(fields); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields(); | std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields(); | ||||
| ASSERT_EQ(resFields, fields); | ASSERT_EQ(resFields, fields); | ||||
| } | } | ||||
| @@ -79,36 +79,37 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||||
| std::pair<uint64_t, std::string> index_field2(schema_id1, "box"); | std::pair<uint64_t, std::string> index_field2(schema_id1, "box"); | ||||
| fields.push_back(index_field1); | fields.push_back(index_field1); | ||||
| fields.push_back(index_field2); | fields.push_back(index_field2); | ||||
| MSRStatus res = header_data.AddIndexFields(fields); | |||||
| ASSERT_EQ(res, SUCCESS); | |||||
| Status status = header_data.AddIndexFields(fields); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | ASSERT_EQ(header_data.GetFields().size(), 2); | ||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field3(schema_id1, "name"); | std::pair<uint64_t, std::string> index_field3(schema_id1, "name"); | ||||
| fields.push_back(index_field3); | fields.push_back(index_field3); | ||||
| res = header_data.AddIndexFields(fields); | |||||
| ASSERT_EQ(res, FAILED); | |||||
| status = header_data.AddIndexFields(fields); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | ASSERT_EQ(header_data.GetFields().size(), 2); | ||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field4(schema_id1, "names"); | std::pair<uint64_t, std::string> index_field4(schema_id1, "names"); | ||||
| fields.push_back(index_field4); | fields.push_back(index_field4); | ||||
| res = header_data.AddIndexFields(fields); | |||||
| ASSERT_EQ(res, FAILED); | |||||
| status = header_data.AddIndexFields(fields); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | ASSERT_EQ(header_data.GetFields().size(), 2); | ||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name"); | std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name"); | ||||
| fields.push_back(index_field5); | fields.push_back(index_field5); | ||||
| res = header_data.AddIndexFields(fields); | |||||
| ASSERT_EQ(res, FAILED); | |||||
| status = header_data.AddIndexFields(fields); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | ASSERT_EQ(header_data.GetFields().size(), 2); | ||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field6(schema_id1, "label"); | std::pair<uint64_t, std::string> index_field6(schema_id1, "label"); | ||||
| fields.push_back(index_field6); | fields.push_back(index_field6); | ||||
| res = header_data.AddIndexFields(fields); | |||||
| ASSERT_EQ(res, FAILED); | |||||
| status = header_data.AddIndexFields(fields); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | ASSERT_EQ(header_data.GetFields().size(), 2); | ||||
| std::string desc_new = "this is a test1"; | std::string desc_new = "this is a test1"; | ||||
| @@ -129,26 +130,26 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||||
| single_fields.push_back("name"); | single_fields.push_back("name"); | ||||
| single_fields.push_back("name"); | single_fields.push_back("name"); | ||||
| single_fields.push_back("box"); | single_fields.push_back("box"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | |||||
| ASSERT_EQ(res, FAILED); | |||||
| status = header_data_new.AddIndexFields(single_fields); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | ASSERT_EQ(header_data_new.GetFields().size(), 1); | ||||
| single_fields.push_back("name"); | single_fields.push_back("name"); | ||||
| single_fields.push_back("box"); | single_fields.push_back("box"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | |||||
| ASSERT_EQ(res, FAILED); | |||||
| status = header_data_new.AddIndexFields(single_fields); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | ASSERT_EQ(header_data_new.GetFields().size(), 1); | ||||
| single_fields.clear(); | single_fields.clear(); | ||||
| single_fields.push_back("names"); | single_fields.push_back("names"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | |||||
| ASSERT_EQ(res, FAILED); | |||||
| status = header_data_new.AddIndexFields(single_fields); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | ASSERT_EQ(header_data_new.GetFields().size(), 1); | ||||
| single_fields.clear(); | single_fields.clear(); | ||||
| single_fields.push_back("box"); | single_fields.push_back("box"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | |||||
| ASSERT_EQ(res, SUCCESS); | |||||
| status = header_data_new.AddIndexFields(single_fields); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 2); | ASSERT_EQ(header_data_new.GetFields().size(), 2); | ||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| @@ -167,8 +167,8 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label"}; | auto column_list = std::vector<std::string>{"label"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| auto status = dataset.Open({file_name}, true, 4, column_list); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -188,16 +188,16 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { | |||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"file_namex"}; | auto column_list = std::vector<std::string>{"file_namex"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, ILLEGAL_COLUMN_LIST); | |||||
| auto status= dataset.Open({file_name}, true, 4, column_list); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| } | } | ||||
| TEST_F(TestShardReader, TestShardVersion) { | TEST_F(TestShardReader, TestShardVersion) { | ||||
| MS_LOG(INFO) << FormatInfo("Test shard version"); | MS_LOG(INFO) << FormatInfo("Test shard version"); | ||||
| std::string file_name = "./imagenet.shard01"; | std::string file_name = "./imagenet.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({file_name}, true, 4); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| auto status = dataset.Open({file_name}, true, 4); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -219,8 +219,8 @@ TEST_F(TestShardReader, TestShardReaderDir) { | |||||
| auto column_list = std::vector<std::string>{"file_name"}; | auto column_list = std::vector<std::string>{"file_name"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({file_name}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, FAILED); | |||||
| auto status = dataset.Open({file_name}, true, 4, column_list); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| } | } | ||||
| TEST_F(TestShardReader, TestShardReaderConsumer) { | TEST_F(TestShardReader, TestShardReaderConsumer) { | ||||
| @@ -61,35 +61,44 @@ TEST_F(TestShardSegment, TestShardSegment) { | |||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open({file_name}, true, 4); | dataset.Open({file_name}, true, 4); | ||||
| auto x = dataset.GetCategoryFields(); | |||||
| for (const auto &fields : x.second) { | |||||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||||
| for (const auto &fields : *fields_ptr) { | |||||
| MS_LOG(INFO) << "Get category field: " << fields; | MS_LOG(INFO) << "Get category field: " << fields; | ||||
| } | } | ||||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||||
| ASSERT_TRUE(dataset.SetCategoryField("laabel_0") == FAILED); | |||||
| status = dataset.SetCategoryField("label"); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| status = dataset.SetCategoryField("laabel_0"); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||||
| auto ret = dataset.ReadAtPageByName("822", 0, 10); | |||||
| auto images = ret.second; | |||||
| MS_LOG(INFO) << "category field: 822, images count: " << images.size() << ", image[0] size: " << images[0].size(); | |||||
| std::shared_ptr<std::string> category_ptr; | |||||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||||
| auto ret1 = dataset.ReadAtPageByName("823", 0, 10); | |||||
| auto images2 = ret1.second; | |||||
| MS_LOG(INFO) << "category field: 823, images count: " << images2.size(); | |||||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| status = dataset.ReadAtPageByName("822", 0, 10, &pages_ptr); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr->size() << ", image[0] size: " << ((*pages_ptr)[0]).size(); | |||||
| auto ret2 = dataset.ReadAtPageById(1, 0, 10); | |||||
| auto images3 = ret2.second; | |||||
| MS_LOG(INFO) << "category id: 1, images count: " << images3.size() << ", image[0] size: " << images3[0].size(); | |||||
| auto pages_ptr_1 = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| status = dataset.ReadAtPageByName("823", 0, 10, &pages_ptr_1); | |||||
| MS_LOG(INFO) << "category field: 823, images count: " << pages_ptr_1->size(); | |||||
| auto ret3 = dataset.ReadAllAtPageByName("822", 0, 10); | |||||
| auto images4 = ret3.second; | |||||
| MS_LOG(INFO) << "category field: 822, images count: " << images4.size(); | |||||
| auto pages_ptr_2 = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| status = dataset.ReadAtPageById(1, 0, 10, &pages_ptr_2); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_2->size() << ", image[0] size: " << ((*pages_ptr_2)[0]).size(); | |||||
| auto ret4 = dataset.ReadAllAtPageById(1, 0, 10); | |||||
| auto images5 = ret4.second; | |||||
| MS_LOG(INFO) << "category id: 1, images count: " << images5.size(); | |||||
| auto pages_ptr_3 = std::make_shared<PAGES>(); | |||||
| status = dataset.ReadAllAtPageByName("822", 0, 10, &pages_ptr_3); | |||||
| MS_LOG(INFO) << "category field: 822, images count: " << pages_ptr_3->size(); | |||||
| auto pages_ptr_4 = std::make_shared<PAGES>(); | |||||
| status = dataset.ReadAllAtPageById(1, 0, 10, &pages_ptr_4); | |||||
| MS_LOG(INFO) << "category id: 1, images count: " << pages_ptr_4->size(); | |||||
| } | } | ||||
| TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { | TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { | ||||
| @@ -99,21 +108,28 @@ TEST_F(TestShardSegment, TestReadAtPageByNameOfCategoryName) { | |||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open({file_name}, true, 4); | dataset.Open({file_name}, true, 4); | ||||
| auto x = dataset.GetCategoryFields(); | |||||
| for (const auto &fields : x.second) { | |||||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||||
| for (const auto &fields : *fields_ptr) { | |||||
| MS_LOG(INFO) << "Get category field: " << fields; | MS_LOG(INFO) << "Get category field: " << fields; | ||||
| } | } | ||||
| string category_name = "82Cus"; | string category_name = "82Cus"; | ||||
| string category_field = "laabel_0"; | string category_field = "laabel_0"; | ||||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||||
| ASSERT_TRUE(dataset.SetCategoryField(category_field) == FAILED); | |||||
| status = dataset.SetCategoryField("label"); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| status = dataset.SetCategoryField(category_field); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||||
| std::shared_ptr<std::string> category_ptr; | |||||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||||
| auto ret = dataset.ReadAtPageByName(category_name, 0, 10); | |||||
| EXPECT_TRUE(ret.first == FAILED); | |||||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| status = dataset.ReadAtPageByName(category_name, 0, 10, &pages_ptr); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| } | } | ||||
| TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { | TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { | ||||
| @@ -123,19 +139,25 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfCategoryId) { | |||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open({file_name}, true, 4); | dataset.Open({file_name}, true, 4); | ||||
| auto x = dataset.GetCategoryFields(); | |||||
| for (const auto &fields : x.second) { | |||||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||||
| for (const auto &fields : *fields_ptr) { | |||||
| MS_LOG(INFO) << "Get category field: " << fields; | MS_LOG(INFO) << "Get category field: " << fields; | ||||
| } | } | ||||
| int64_t categoryId = 2251799813685247; | int64_t categoryId = 2251799813685247; | ||||
| MS_LOG(INFO) << "Input category id: " << categoryId; | MS_LOG(INFO) << "Input category id: " << categoryId; | ||||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||||
| status = dataset.SetCategoryField("label"); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| std::shared_ptr<std::string> category_ptr; | |||||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||||
| auto ret2 = dataset.ReadAtPageById(categoryId, 0, 10); | |||||
| EXPECT_TRUE(ret2.first == FAILED); | |||||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| status = dataset.ReadAtPageById(categoryId, 0, 10, &pages_ptr); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| } | } | ||||
| TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { | TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { | ||||
| @@ -145,19 +167,27 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageNo) { | |||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open({file_name}, true, 4); | dataset.Open({file_name}, true, 4); | ||||
| auto x = dataset.GetCategoryFields(); | |||||
| for (const auto &fields : x.second) { | |||||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||||
| for (const auto &fields : *fields_ptr) { | |||||
| MS_LOG(INFO) << "Get category field: " << fields; | MS_LOG(INFO) << "Get category field: " << fields; | ||||
| } | } | ||||
| int64_t page_no = 2251799813685247; | int64_t page_no = 2251799813685247; | ||||
| MS_LOG(INFO) << "Input page no: " << page_no; | MS_LOG(INFO) << "Input page no: " << page_no; | ||||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||||
| status = dataset.SetCategoryField("label"); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| auto ret2 = dataset.ReadAtPageById(1, page_no, 10); | |||||
| EXPECT_TRUE(ret2.first == FAILED); | |||||
| std::shared_ptr<std::string> category_ptr; | |||||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| status = dataset.ReadAtPageById(1, page_no, 10, &pages_ptr); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| } | } | ||||
| TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { | TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { | ||||
| @@ -167,19 +197,26 @@ TEST_F(TestShardSegment, TestReadAtPageByIdOfPageRows) { | |||||
| ShardSegment dataset; | ShardSegment dataset; | ||||
| dataset.Open({file_name}, true, 4); | dataset.Open({file_name}, true, 4); | ||||
| auto x = dataset.GetCategoryFields(); | |||||
| for (const auto &fields : x.second) { | |||||
| auto fields_ptr = std::make_shared<vector<std::string>>(); | |||||
| auto status = dataset.GetCategoryFields(&fields_ptr); | |||||
| for (const auto &fields : *fields_ptr) { | |||||
| MS_LOG(INFO) << "Get category field: " << fields; | MS_LOG(INFO) << "Get category field: " << fields; | ||||
| } | } | ||||
| int64_t pageRows = 0; | int64_t pageRows = 0; | ||||
| MS_LOG(INFO) << "Input page rows: " << pageRows; | MS_LOG(INFO) << "Input page rows: " << pageRows; | ||||
| ASSERT_TRUE(dataset.SetCategoryField("label") == SUCCESS); | |||||
| MS_LOG(INFO) << "Read category info: " << dataset.ReadCategoryInfo().second; | |||||
| status = dataset.SetCategoryField("label"); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| std::shared_ptr<std::string> category_ptr; | |||||
| status = dataset.ReadCategoryInfo(&category_ptr); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Read category info: " << *category_ptr; | |||||
| auto ret2 = dataset.ReadAtPageById(1, 0, pageRows); | |||||
| EXPECT_TRUE(ret2.first == FAILED); | |||||
| auto pages_ptr = std::make_shared<std::vector<std::vector<uint8_t>>>(); | |||||
| status = dataset.ReadAtPageById(1, 0, pageRows, &pages_ptr); | |||||
| EXPECT_FALSE(status.IsOk()); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| @@ -60,8 +60,8 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { | |||||
| std::string filename = "./OneSample.shard01"; | std::string filename = "./OneSample.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({filename}, true, 4); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| auto status = dataset.Open({filename}, true, 4); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| while (true) { | while (true) { | ||||
| @@ -675,8 +675,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) { | |||||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | ||||
| // write rawdata | // write rawdata | ||||
| MSRStatus res = fw.WriteRawData(rawdatas, bin_data); | |||||
| ASSERT_EQ(res, SUCCESS); | |||||
| auto status = fw.WriteRawData(rawdatas, bin_data); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| for (const auto &filename : file_names) { | for (const auto &filename : file_names) { | ||||
| auto filename_db = filename + ".db"; | auto filename_db = filename + ".db"; | ||||
| remove(common::SafeCStr(filename_db)); | remove(common::SafeCStr(filename_db)); | ||||
| @@ -716,7 +716,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { | |||||
| fields.push_back(index_field2); | fields.push_back(index_field2); | ||||
| // add index to shardHeader | // add index to shardHeader | ||||
| ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS); | |||||
| auto status = header_data.AddIndexFields(fields); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | MS_LOG(INFO) << "Init Index Fields Already."; | ||||
| // load meta data | // load meta data | ||||
| @@ -736,28 +737,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { | |||||
| } | } | ||||
| mindrecord::ShardWriter fw_init; | mindrecord::ShardWriter fw_init; | ||||
| ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); | |||||
| status = fw_init.Open(file_names); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // set shardHeader | // set shardHeader | ||||
| ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS); | |||||
| status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // write raw data | // write raw data | ||||
| ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); | |||||
| ASSERT_TRUE(fw_init.Commit() == SUCCESS); | |||||
| status = fw_init.WriteRawData(rawdatas, bin_data); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| status = fw_init.Commit(); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // create the index file | // create the index file | ||||
| std::string filename = "./imagenet.shard01"; | std::string filename = "./imagenet.shard01"; | ||||
| mindrecord::ShardIndexGenerator sg{filename}; | mindrecord::ShardIndexGenerator sg{filename}; | ||||
| sg.Build(); | sg.Build(); | ||||
| ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); | |||||
| status = sg.WriteToDatabase(); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Done create index"; | MS_LOG(INFO) << "Done create index"; | ||||
| // read the mindrecord file | // read the mindrecord file | ||||
| filename = "./imagenet.shard01"; | filename = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label", "file_name", "data"}; | auto column_list = std::vector<std::string>{"label", "file_name", "data"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| status = dataset.Open({filename}, true, 4, column_list); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int count = 0; | int count = 0; | ||||
| @@ -822,28 +829,34 @@ TEST_F(TestShardWriter, TestShardNoBlob) { | |||||
| } | } | ||||
| mindrecord::ShardWriter fw_init; | mindrecord::ShardWriter fw_init; | ||||
| ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); | |||||
| auto status = fw_init.Open(file_names); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // set shardHeader | // set shardHeader | ||||
| ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS); | |||||
| status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // write raw data | // write raw data | ||||
| ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); | |||||
| ASSERT_TRUE(fw_init.Commit() == SUCCESS); | |||||
| status = fw_init.WriteRawData(rawdatas, bin_data); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| status = fw_init.Commit(); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // create the index file | // create the index file | ||||
| std::string filename = "./imagenet.shard01"; | std::string filename = "./imagenet.shard01"; | ||||
| mindrecord::ShardIndexGenerator sg{filename}; | mindrecord::ShardIndexGenerator sg{filename}; | ||||
| sg.Build(); | sg.Build(); | ||||
| ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); | |||||
| status = sg.WriteToDatabase(); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Done create index"; | MS_LOG(INFO) << "Done create index"; | ||||
| // read the mindrecord file | // read the mindrecord file | ||||
| filename = "./imagenet.shard01"; | filename = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label", "file_name"}; | auto column_list = std::vector<std::string>{"label", "file_name"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| status = dataset.Open({filename}, true, 4, column_list); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int count = 0; | int count = 0; | ||||
| @@ -896,7 +909,8 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { | |||||
| fields.push_back(index_field1); | fields.push_back(index_field1); | ||||
| // add index to shardHeader | // add index to shardHeader | ||||
| ASSERT_EQ(header_data.AddIndexFields(fields), SUCCESS); | |||||
| auto status = header_data.AddIndexFields(fields); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Init Index Fields Already."; | MS_LOG(INFO) << "Init Index Fields Already."; | ||||
| // load meta data | // load meta data | ||||
| @@ -916,28 +930,34 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { | |||||
| } | } | ||||
| mindrecord::ShardWriter fw_init; | mindrecord::ShardWriter fw_init; | ||||
| ASSERT_TRUE(fw_init.Open(file_names) == SUCCESS); | |||||
| status = fw_init.Open(file_names); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // set shardHeader | // set shardHeader | ||||
| ASSERT_TRUE(fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)) == SUCCESS); | |||||
| status = fw_init.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // write raw data | // write raw data | ||||
| ASSERT_TRUE(fw_init.WriteRawData(rawdatas, bin_data) == SUCCESS); | |||||
| ASSERT_TRUE(fw_init.Commit() == SUCCESS); | |||||
| status = fw_init.WriteRawData(rawdatas, bin_data); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| status = fw_init.Commit(); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| // create the index file | // create the index file | ||||
| std::string filename = "./imagenet.shard01"; | std::string filename = "./imagenet.shard01"; | ||||
| mindrecord::ShardIndexGenerator sg{filename}; | mindrecord::ShardIndexGenerator sg{filename}; | ||||
| sg.Build(); | sg.Build(); | ||||
| ASSERT_TRUE(sg.WriteToDatabase() == SUCCESS); | |||||
| status = sg.WriteToDatabase(); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| MS_LOG(INFO) << "Done create index"; | MS_LOG(INFO) << "Done create index"; | ||||
| // read the mindrecord file | // read the mindrecord file | ||||
| filename = "./imagenet.shard01"; | filename = "./imagenet.shard01"; | ||||
| auto column_list = std::vector<std::string>{"label", "data"}; | auto column_list = std::vector<std::string>{"label", "data"}; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({filename}, true, 4, column_list); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| status = dataset.Open({filename}, true, 4, column_list); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int count = 0; | int count = 0; | ||||
| @@ -1043,8 +1063,8 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { | |||||
| filename = "./TenSampleFortyShard.shard01"; | filename = "./TenSampleFortyShard.shard01"; | ||||
| ShardReader dataset; | ShardReader dataset; | ||||
| MSRStatus ret = dataset.Open({filename}, true, 4); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| auto status = dataset.Open({filename}, true, 4); | |||||
| EXPECT_TRUE(status.IsOk()); | |||||
| dataset.Launch(); | dataset.Launch(); | ||||
| int count = 0; | int count = 0; | ||||
| @@ -95,7 +95,7 @@ def test_invalid_mindrecord(): | |||||
| f.write('just for test') | f.write('just for test') | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid file content. path:"): | |||||
| data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers) | data_set = ds.MindDataset('dummy.mindrecord', columns_list, num_readers) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | ||||
| @@ -114,7 +114,7 @@ def test_minddataset_lack_db(): | |||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid database file:"): | |||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | ||||
| @@ -133,7 +133,7 @@ def test_cv_minddataset_pk_sample_error_class_column(): | |||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| sampler = ds.PKSampler(5, None, True, 'no_exist_column') | sampler = ds.PKSampler(5, None, True, 'no_exist_column') | ||||
| with pytest.raises(Exception, match="MindRecordOp launch failed"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Failed to launch read threads."): | |||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler) | data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | for _ in data_set.create_dict_iterator(num_epochs=1, output_numpy=True): | ||||
| @@ -162,7 +162,7 @@ def test_cv_minddataset_reader_different_schema(): | |||||
| create_diff_schema_cv_mindrecord(1) | create_diff_schema_cv_mindrecord(1) | ||||
| columns_list = ["data", "label"] | columns_list = ["data", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||||
| with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"): | |||||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | ||||
| num_readers) | num_readers) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| @@ -179,7 +179,7 @@ def test_cv_minddataset_reader_different_page_size(): | |||||
| create_diff_page_size_cv_mindrecord(1) | create_diff_page_size_cv_mindrecord(1) | ||||
| columns_list = ["data", "label"] | columns_list = ["data", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| with pytest.raises(Exception, match="MindRecordOp init failed"): | |||||
| with pytest.raises(RuntimeError, match="Mindrecord files meta information is different"): | |||||
| data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | data_set = ds.MindDataset([CV_FILE_NAME, CV1_FILE_NAME], columns_list, | ||||
| num_readers) | num_readers) | ||||
| num_iter = 0 | num_iter = 0 | ||||
| @@ -19,7 +19,6 @@ import pytest | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.mindrecord import Cifar100ToMR | from mindspore.mindrecord import Cifar100ToMR | ||||
| from mindspore.mindrecord import FileReader | from mindspore.mindrecord import FileReader | ||||
| from mindspore.mindrecord import MRMOpenError | |||||
| from mindspore.mindrecord import SUCCESS | from mindspore.mindrecord import SUCCESS | ||||
| CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | CIFAR100_DIR = "../data/mindrecord/testCifar100Data" | ||||
| @@ -119,8 +118,8 @@ def test_cifar100_to_mindrecord_directory(fixture_file): | |||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when destination path is directory. | when destination path is directory. | ||||
| """ | """ | ||||
| with pytest.raises(MRMOpenError, | |||||
| match="MindRecord File could not open successfully"): | |||||
| with pytest.raises(RuntimeError, | |||||
| match="MindRecord file already existed, please delete file:"): | |||||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, | cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, | ||||
| CIFAR100_DIR) | CIFAR100_DIR) | ||||
| cifar100_transformer.transform() | cifar100_transformer.transform() | ||||
| @@ -130,8 +129,8 @@ def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file): | |||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when destination path equals source path. | when destination path equals source path. | ||||
| """ | """ | ||||
| with pytest.raises(MRMOpenError, | |||||
| match="MindRecord File could not open successfully"): | |||||
| with pytest.raises(RuntimeError, | |||||
| match="indRecord file already existed, please delete file:"): | |||||
| cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, | cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, | ||||
| CIFAR100_DIR + "/train") | CIFAR100_DIR + "/train") | ||||
| cifar100_transformer.transform() | cifar100_transformer.transform() | ||||
| @@ -19,7 +19,7 @@ import pytest | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.mindrecord import Cifar10ToMR | from mindspore.mindrecord import Cifar10ToMR | ||||
| from mindspore.mindrecord import FileReader | from mindspore.mindrecord import FileReader | ||||
| from mindspore.mindrecord import MRMOpenError, SUCCESS | |||||
| from mindspore.mindrecord import SUCCESS | |||||
| CIFAR10_DIR = "../data/mindrecord/testCifar10Data" | CIFAR10_DIR = "../data/mindrecord/testCifar10Data" | ||||
| MINDRECORD_FILE = "./cifar10.mindrecord" | MINDRECORD_FILE = "./cifar10.mindrecord" | ||||
| @@ -146,8 +146,8 @@ def test_cifar10_to_mindrecord_directory(fixture_file): | |||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when destination path is directory. | when destination path is directory. | ||||
| """ | """ | ||||
| with pytest.raises(MRMOpenError, | |||||
| match="MindRecord File could not open successfully"): | |||||
| with pytest.raises(RuntimeError, | |||||
| match="MindRecord file already existed, please delete file:"): | |||||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR) | cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, CIFAR10_DIR) | ||||
| cifar10_transformer.transform() | cifar10_transformer.transform() | ||||
| @@ -157,8 +157,8 @@ def test_cifar10_to_mindrecord_filename_equals_cifar10(): | |||||
| test transform cifar10 dataset to mindrecord | test transform cifar10 dataset to mindrecord | ||||
| when destination path equals source path. | when destination path equals source path. | ||||
| """ | """ | ||||
| with pytest.raises(MRMOpenError, | |||||
| match="MindRecord File could not open successfully"): | |||||
| with pytest.raises(RuntimeError, | |||||
| match="MindRecord file already existed, please delete file:"): | |||||
| cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, | cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, | ||||
| CIFAR10_DIR + "/data_batch_0") | CIFAR10_DIR + "/data_batch_0") | ||||
| cifar10_transformer.transform() | cifar10_transformer.transform() | ||||
| @@ -21,8 +21,7 @@ from utils import get_data | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS | from mindspore.mindrecord import FileWriter, FileReader, MindPage, SUCCESS | ||||
| from mindspore.mindrecord import MRMOpenError, MRMGenerateIndexError, ParamValueError, MRMGetMetaError, \ | |||||
| MRMFetchDataError | |||||
| from mindspore.mindrecord import ParamValueError, MRMGetMetaError | |||||
| CV_FILE_NAME = "./imagenet.mindrecord" | CV_FILE_NAME = "./imagenet.mindrecord" | ||||
| NLP_FILE_NAME = "./aclImdb.mindrecord" | NLP_FILE_NAME = "./aclImdb.mindrecord" | ||||
| @@ -106,21 +105,19 @@ def create_cv_mindrecord(files_num): | |||||
| def test_lack_partition_and_db(): | def test_lack_partition_and_db(): | ||||
| """test file reader when mindrecord file does not exist.""" | """test file reader when mindrecord file does not exist.""" | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| reader = FileReader('dummy.mindrecord') | reader = FileReader('dummy.mindrecord') | ||||
| reader.close() | reader.close() | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||||
| def test_lack_db(fixture_cv_file): | def test_lack_db(fixture_cv_file): | ||||
| """test file reader when db file does not exist.""" | """test file reader when db file does not exist.""" | ||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| reader = FileReader(CV_FILE_NAME) | reader = FileReader(CV_FILE_NAME) | ||||
| reader.close() | reader.close() | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid database file:' in str(err.value) | |||||
| def test_lack_some_partition_and_db(fixture_cv_file): | def test_lack_some_partition_and_db(fixture_cv_file): | ||||
| """test file reader when some partition and db do not exist.""" | """test file reader when some partition and db do not exist.""" | ||||
| @@ -129,11 +126,10 @@ def test_lack_some_partition_and_db(fixture_cv_file): | |||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| os.remove("{}".format(paths[3])) | os.remove("{}".format(paths[3])) | ||||
| os.remove("{}.db".format(paths[3])) | os.remove("{}.db".format(paths[3])) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| reader = FileReader(CV_FILE_NAME + "0") | reader = FileReader(CV_FILE_NAME + "0") | ||||
| reader.close() | reader.close() | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||||
| def test_lack_some_partition_first(fixture_cv_file): | def test_lack_some_partition_first(fixture_cv_file): | ||||
| """test file reader when first partition does not exist.""" | """test file reader when first partition does not exist.""" | ||||
| @@ -141,11 +137,10 @@ def test_lack_some_partition_first(fixture_cv_file): | |||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| os.remove("{}".format(paths[0])) | os.remove("{}".format(paths[0])) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| reader = FileReader(CV_FILE_NAME + "0") | reader = FileReader(CV_FILE_NAME + "0") | ||||
| reader.close() | reader.close() | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||||
| def test_lack_some_partition_middle(fixture_cv_file): | def test_lack_some_partition_middle(fixture_cv_file): | ||||
| """test file reader when some partition does not exist.""" | """test file reader when some partition does not exist.""" | ||||
| @@ -153,11 +148,10 @@ def test_lack_some_partition_middle(fixture_cv_file): | |||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| os.remove("{}".format(paths[1])) | os.remove("{}".format(paths[1])) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| reader = FileReader(CV_FILE_NAME + "0") | reader = FileReader(CV_FILE_NAME + "0") | ||||
| reader.close() | reader.close() | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||||
| def test_lack_some_partition_last(fixture_cv_file): | def test_lack_some_partition_last(fixture_cv_file): | ||||
| """test file reader when last partition does not exist.""" | """test file reader when last partition does not exist.""" | ||||
| @@ -165,11 +159,10 @@ def test_lack_some_partition_last(fixture_cv_file): | |||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| os.remove("{}".format(paths[3])) | os.remove("{}".format(paths[3])) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| reader = FileReader(CV_FILE_NAME + "0") | reader = FileReader(CV_FILE_NAME + "0") | ||||
| reader.close() | reader.close() | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||||
| def test_mindpage_lack_some_partition(fixture_cv_file): | def test_mindpage_lack_some_partition(fixture_cv_file): | ||||
| """test page reader when some partition does not exist.""" | """test page reader when some partition does not exist.""" | ||||
| @@ -177,10 +170,9 @@ def test_mindpage_lack_some_partition(fixture_cv_file): | |||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| os.remove("{}".format(paths[0])) | os.remove("{}".format(paths[0])) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| MindPage(CV_FILE_NAME + "0") | MindPage(CV_FILE_NAME + "0") | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid file path:' in str(err.value) | |||||
| def test_lack_some_db(fixture_cv_file): | def test_lack_some_db(fixture_cv_file): | ||||
| """test file reader when some db does not exist.""" | """test file reader when some db does not exist.""" | ||||
| @@ -188,11 +180,10 @@ def test_lack_some_db(fixture_cv_file): | |||||
| paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0')) | ||||
| for x in range(FILES_NUM)] | for x in range(FILES_NUM)] | ||||
| os.remove("{}.db".format(paths[3])) | os.remove("{}.db".format(paths[3])) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| reader = FileReader(CV_FILE_NAME + "0") | reader = FileReader(CV_FILE_NAME + "0") | ||||
| reader.close() | reader.close() | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid database file:' in str(err.value) | |||||
| def test_invalid_mindrecord(): | def test_invalid_mindrecord(): | ||||
| @@ -200,10 +191,9 @@ def test_invalid_mindrecord(): | |||||
| with open(CV_FILE_NAME, 'w') as f: | with open(CV_FILE_NAME, 'w') as f: | ||||
| dummy = 's' * 100 | dummy = 's' * 100 | ||||
| f.write(dummy) | f.write(dummy) | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| FileReader(CV_FILE_NAME) | FileReader(CV_FILE_NAME) | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Invalid file content. path:' in str(err.value) | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| def test_invalid_db(fixture_cv_file): | def test_invalid_db(fixture_cv_file): | ||||
| @@ -212,27 +202,26 @@ def test_invalid_db(fixture_cv_file): | |||||
| os.remove("imagenet.mindrecord.db") | os.remove("imagenet.mindrecord.db") | ||||
| with open('imagenet.mindrecord.db', 'w') as f: | with open('imagenet.mindrecord.db', 'w') as f: | ||||
| f.write('just for test') | f.write('just for test') | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| FileReader('imagenet.mindrecord') | FileReader('imagenet.mindrecord') | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| in str(err.value) | |||||
| assert 'Unexpected error. Error in execute sql:' in str(err.value) | |||||
| def test_overwrite_invalid_mindrecord(fixture_cv_file): | def test_overwrite_invalid_mindrecord(fixture_cv_file): | ||||
| """test file writer when overwrite invalid mindreocrd file.""" | """test file writer when overwrite invalid mindreocrd file.""" | ||||
| with open(CV_FILE_NAME, 'w') as f: | with open(CV_FILE_NAME, 'w') as f: | ||||
| f.write('just for test') | f.write('just for test') | ||||
| with pytest.raises(MRMOpenError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| assert '[MRMOpenError]: MindRecord File could not open successfully.' \ | |||||
| assert 'Unexpected error. MindRecord file already existed, please delete file:' \ | |||||
| in str(err.value) | in str(err.value) | ||||
| def test_overwrite_invalid_db(fixture_cv_file): | def test_overwrite_invalid_db(fixture_cv_file): | ||||
| """test file writer when overwrite invalid db file.""" | """test file writer when overwrite invalid db file.""" | ||||
| with open('imagenet.mindrecord.db', 'w') as f: | with open('imagenet.mindrecord.db', 'w') as f: | ||||
| f.write('just for test') | f.write('just for test') | ||||
| with pytest.raises(MRMGenerateIndexError) as err: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| create_cv_mindrecord(1) | create_cv_mindrecord(1) | ||||
| assert '[MRMGenerateIndexError]: Failed to generate index.' in str(err.value) | |||||
| assert 'Unexpected error. Failed to write data to db.' in str(err.value) | |||||
| def test_read_after_close(fixture_cv_file): | def test_read_after_close(fixture_cv_file): | ||||
| """test file reader when close read.""" | """test file reader when close read.""" | ||||
| @@ -302,7 +291,7 @@ def test_mindpage_pageno_pagesize_not_int(fixture_cv_file): | |||||
| with pytest.raises(ParamValueError): | with pytest.raises(ParamValueError): | ||||
| reader.read_at_page_by_name("822", 0, "qwer") | reader.read_at_page_by_name("822", 0, "qwer") | ||||
| with pytest.raises(MRMFetchDataError, match="Failed to fetch data by category."): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"): | |||||
| reader.read_at_page_by_id(99999, 0, 1) | reader.read_at_page_by_id(99999, 0, 1) | ||||
| @@ -320,10 +309,10 @@ def test_mindpage_filename_not_exist(fixture_cv_file): | |||||
| info = reader.read_category_info() | info = reader.read_category_info() | ||||
| logger.info("category info: {}".format(info)) | logger.info("category info: {}".format(info)) | ||||
| with pytest.raises(MRMFetchDataError): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid category id:"): | |||||
| reader.read_at_page_by_id(9999, 0, 1) | reader.read_at_page_by_id(9999, 0, 1) | ||||
| with pytest.raises(MRMFetchDataError): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Invalid category name."): | |||||
| reader.read_at_page_by_name("abc.jpg", 0, 1) | reader.read_at_page_by_name("abc.jpg", 0, 1) | ||||
| with pytest.raises(ParamValueError): | with pytest.raises(ParamValueError): | ||||
| @@ -475,7 +464,7 @@ def test_write_with_invalid_data(): | |||||
| mindrecord_file_name = "test.mindrecord" | mindrecord_file_name = "test.mindrecord" | ||||
| # field: file_name => filename | # field: file_name => filename | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -510,7 +499,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # field: mask => masks | # field: mask => masks | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -545,7 +534,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # field: data => image | # field: data => image | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -580,7 +569,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # field: label => labels | # field: label => labels | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -615,7 +604,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # field: score => scores | # field: score => scores | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -650,7 +639,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # string type with int value | # string type with int value | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -685,7 +674,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # field with int64 type, but the real data is string | # field with int64 type, but the real data is string | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -720,7 +709,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # bytes field is string | # bytes field is string | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -755,7 +744,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # field is not numpy type | # field is not numpy type | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||
| @@ -790,7 +779,7 @@ def test_write_with_invalid_data(): | |||||
| writer.commit() | writer.commit() | ||||
| # not enough field | # not enough field | ||||
| with pytest.raises(Exception, match="Failed to write dataset"): | |||||
| with pytest.raises(RuntimeError, match="Unexpected error. Data size is not positive."): | |||||
| remove_one_file(mindrecord_file_name) | remove_one_file(mindrecord_file_name) | ||||
| remove_one_file(mindrecord_file_name + ".db") | remove_one_file(mindrecord_file_name + ".db") | ||||