|
|
|
@@ -51,7 +51,8 @@ ShardReader::ShardReader() |
|
|
|
lazy_load_(false), |
|
|
|
shard_sample_count_() {} |
|
|
|
|
|
|
|
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { |
|
|
|
std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::string &file_path, |
|
|
|
std::shared_ptr<json> meta_data_ptr) { |
|
|
|
if (!IsLegalFile(file_path)) { |
|
|
|
return {FAILED, {}}; |
|
|
|
} |
|
|
|
@@ -60,16 +61,16 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s |
|
|
|
return {FAILED, {}}; |
|
|
|
} |
|
|
|
auto header = ret.second; |
|
|
|
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, |
|
|
|
{"version", header["version"]}, {"index_fields", header["index_fields"]}, |
|
|
|
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; |
|
|
|
*meta_data_ptr = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, |
|
|
|
{"version", header["version"]}, {"index_fields", header["index_fields"]}, |
|
|
|
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; |
|
|
|
return {SUCCESS, header["shard_addresses"]}; |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool load_dataset) { |
|
|
|
std::string file_path = file_paths[0]; |
|
|
|
json first_meta_data = json(); |
|
|
|
auto ret = GetMeta(file_path, first_meta_data); |
|
|
|
auto first_meta_data_ptr = std::make_shared<json>(); |
|
|
|
auto ret = GetMeta(file_path, first_meta_data_ptr); |
|
|
|
if (ret.first != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -91,12 +92,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
for (const auto &file : file_paths_) { |
|
|
|
json meta_data = json(); |
|
|
|
auto ret1 = GetMeta(file, meta_data); |
|
|
|
auto meta_data_ptr = std::make_shared<json>(); |
|
|
|
auto ret1 = GetMeta(file, meta_data_ptr); |
|
|
|
if (ret1.first != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (meta_data != first_meta_data) { |
|
|
|
if (*meta_data_ptr != *first_meta_data_ptr) { |
|
|
|
MS_LOG(ERROR) << "Mindrecord files meta information is different."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -140,7 +141,7 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa |
|
|
|
header_size_ = shard_header_->GetHeaderSize(); |
|
|
|
page_size_ = shard_header_->GetPageSize(); |
|
|
|
// version < 3.0 |
|
|
|
if (first_meta_data["version"] < kVersion) { |
|
|
|
if ((*first_meta_data_ptr)["version"] < kVersion) { |
|
|
|
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false); |
|
|
|
} else { |
|
|
|
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true); |
|
|
|
@@ -314,14 +315,14 @@ MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) { |
|
|
|
|
|
|
|
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, |
|
|
|
std::shared_ptr<std::fstream> fs, |
|
|
|
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id, |
|
|
|
const std::vector<std::string> &columns, |
|
|
|
std::vector<std::vector<json>> &column_values) { |
|
|
|
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, |
|
|
|
int shard_id, const std::vector<std::string> &columns, |
|
|
|
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) { |
|
|
|
for (int i = 0; i < static_cast<int>(labels.size()); ++i) { |
|
|
|
uint64_t group_id = std::stoull(labels[i][0]); |
|
|
|
uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len; |
|
|
|
uint64_t offset_end = std::stoull(labels[i][2]); |
|
|
|
offsets[shard_id].emplace_back( |
|
|
|
(*offset_ptr)[shard_id].emplace_back( |
|
|
|
std::vector<uint64_t>{static_cast<uint64_t>(shard_id), group_id, offset_start, offset_end}); |
|
|
|
if (!all_in_index_) { |
|
|
|
int raw_page_id = std::stoi(labels[i][3]); |
|
|
|
@@ -353,7 +354,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str |
|
|
|
} else { |
|
|
|
tmp = label_json; |
|
|
|
} |
|
|
|
column_values[shard_id].emplace_back(tmp); |
|
|
|
(*col_val_ptr)[shard_id].emplace_back(tmp); |
|
|
|
} else { |
|
|
|
json construct_json; |
|
|
|
for (unsigned int j = 0; j < columns.size(); ++j) { |
|
|
|
@@ -373,7 +374,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str |
|
|
|
construct_json[columns[j]] = std::string(labels[i][j + 3]); |
|
|
|
} |
|
|
|
} |
|
|
|
column_values[shard_id].emplace_back(construct_json); |
|
|
|
(*col_val_ptr)[shard_id].emplace_back(construct_json); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -381,8 +382,8 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector<std::string> &columns, |
|
|
|
std::vector<std::vector<std::vector<uint64_t>>> &offsets, |
|
|
|
std::vector<std::vector<json>> &column_values) { |
|
|
|
std::shared_ptr<std::vector<std::vector<std::vector<uint64_t>>>> offset_ptr, |
|
|
|
std::shared_ptr<std::vector<std::vector<json>>> col_val_ptr) { |
|
|
|
auto db = database_paths_[shard_id]; |
|
|
|
std::vector<std::vector<std::string>> labels; |
|
|
|
char *errmsg = nullptr; |
|
|
|
@@ -406,10 +407,11 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, |
|
|
|
} |
|
|
|
} |
|
|
|
sqlite3_free(errmsg); |
|
|
|
return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); |
|
|
|
return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { |
|
|
|
MSRStatus ShardReader::GetAllClasses(const std::string &category_field, |
|
|
|
std::shared_ptr<std::set<std::string>> category_ptr) { |
|
|
|
std::map<std::string, uint64_t> index_columns; |
|
|
|
for (auto &field : GetShardHeader()->GetFields()) { |
|
|
|
index_columns[field.second] = field.first; |
|
|
|
@@ -425,7 +427,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set |
|
|
|
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; |
|
|
|
std::vector<std::thread> threads = std::vector<std::thread>(shard_count_); |
|
|
|
for (int x = 0; x < shard_count_; x++) { |
|
|
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); |
|
|
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
for (int x = 0; x < shard_count_; x++) { |
|
|
|
@@ -434,8 +436,8 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, |
|
|
|
std::set<std::string> &categories) { |
|
|
|
void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, |
|
|
|
std::shared_ptr<std::set<std::string>> category_ptr) { |
|
|
|
if (nullptr == db) { |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -452,20 +454,22 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string |
|
|
|
MS_LOG(INFO) << "Get " << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index."; |
|
|
|
std::lock_guard<std::mutex> lck(shard_locker_); |
|
|
|
for (int i = 0; i < static_cast<int>(columns.size()); ++i) { |
|
|
|
categories.emplace(columns[i][0]); |
|
|
|
category_ptr->emplace(columns[i][0]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) { |
|
|
|
ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector<std::string> &columns) { |
|
|
|
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; |
|
|
|
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{}); |
|
|
|
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{}); |
|
|
|
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>( |
|
|
|
shard_count_, std::vector<std::vector<uint64_t>>{}); |
|
|
|
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{}); |
|
|
|
|
|
|
|
if (all_in_index_) { |
|
|
|
for (unsigned int i = 0; i < columns.size(); ++i) { |
|
|
|
fields += ','; |
|
|
|
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); |
|
|
|
if (ret.first != SUCCESS) { |
|
|
|
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); |
|
|
|
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr)); |
|
|
|
} |
|
|
|
fields += ret.second; |
|
|
|
} |
|
|
|
@@ -477,27 +481,27 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) { |
|
|
|
|
|
|
|
std::vector<std::thread> thread_read_db = std::vector<std::thread>(shard_count_); |
|
|
|
for (int x = 0; x < shard_count_; x++) { |
|
|
|
thread_read_db[x] = |
|
|
|
std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values)); |
|
|
|
thread_read_db[x] = std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, offset_ptr, col_val_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
for (int x = 0; x < shard_count_; x++) { |
|
|
|
thread_read_db[x].join(); |
|
|
|
} |
|
|
|
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values)); |
|
|
|
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr)); |
|
|
|
} |
|
|
|
|
|
|
|
ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std::string> &columns, |
|
|
|
const uint32_t &shard_id, const uint32_t &sample_id) { |
|
|
|
std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; |
|
|
|
std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{}); |
|
|
|
std::vector<std::vector<json>> column_values(shard_count_, std::vector<json>{}); |
|
|
|
auto offset_ptr = std::make_shared<std::vector<std::vector<std::vector<uint64_t>>>>( |
|
|
|
shard_count_, std::vector<std::vector<uint64_t>>{}); |
|
|
|
auto col_val_ptr = std::make_shared<std::vector<std::vector<json>>>(shard_count_, std::vector<json>{}); |
|
|
|
if (all_in_index_) { |
|
|
|
for (unsigned int i = 0; i < columns.size(); ++i) { |
|
|
|
fields += ','; |
|
|
|
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); |
|
|
|
if (ret.first != SUCCESS) { |
|
|
|
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); |
|
|
|
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr)); |
|
|
|
} |
|
|
|
fields += ret.second; |
|
|
|
} |
|
|
|
@@ -507,12 +511,12 @@ ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector<std:: |
|
|
|
|
|
|
|
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE ROW_ID = " + std::to_string(sample_id); |
|
|
|
|
|
|
|
if (ReadAllRowsInShard(shard_id, sql, columns, offsets, column_values) != SUCCESS) { |
|
|
|
if (ReadAllRowsInShard(shard_id, sql, columns, offset_ptr, col_val_ptr) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << "Read shard id: " << shard_id << ", sample id: " << sample_id << " from index failed."; |
|
|
|
return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); |
|
|
|
return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr)); |
|
|
|
} |
|
|
|
|
|
|
|
return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values)); |
|
|
|
return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr)); |
|
|
|
} |
|
|
|
|
|
|
|
ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const std::vector<std::string> &columns) { |
|
|
|
@@ -550,7 +554,10 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, |
|
|
|
uint64_t page_length = page->GetPageSize(); |
|
|
|
uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; |
|
|
|
std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); |
|
|
|
|
|
|
|
if (image_offset.empty()) { |
|
|
|
return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::vector<std::vector<uint64_t>>(), |
|
|
|
std::vector<json>()); |
|
|
|
} |
|
|
|
auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); |
|
|
|
if (status_labels.first != SUCCESS) { |
|
|
|
return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>()); |
|
|
|
@@ -601,7 +608,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int |
|
|
|
db = nullptr; |
|
|
|
return std::vector<std::vector<uint64_t>>(); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << "records from index."; |
|
|
|
MS_LOG(DEBUG) << "Get " << static_cast<int>(image_offsets.size()) << " records from index."; |
|
|
|
} |
|
|
|
std::vector<std::vector<uint64_t>> res; |
|
|
|
for (int i = static_cast<int>(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector<uint64_t>{0, 0}); |
|
|
|
@@ -614,6 +621,44 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int |
|
|
|
return res; |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<MSRStatus, std::vector<uint64_t>> ShardReader::GetPagesByCategory( |
|
|
|
int shard_id, const std::pair<std::string, std::string> &criteria) { |
|
|
|
auto db = database_paths_[shard_id]; |
|
|
|
|
|
|
|
std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 "; |
|
|
|
|
|
|
|
if (!criteria.first.empty()) { |
|
|
|
auto schema = shard_header_->GetSchemas()[0]->GetSchema(); |
|
|
|
|
|
|
|
if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { |
|
|
|
sql += |
|
|
|
" AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; |
|
|
|
} else { |
|
|
|
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" + |
|
|
|
criteria.second + "'"; |
|
|
|
} |
|
|
|
} |
|
|
|
sql += ";"; |
|
|
|
std::vector<std::vector<std::string>> page_ids; |
|
|
|
char *errmsg = nullptr; |
|
|
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg); |
|
|
|
if (rc != SQLITE_OK) { |
|
|
|
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; |
|
|
|
sqlite3_free(errmsg); |
|
|
|
sqlite3_close(db); |
|
|
|
db = nullptr; |
|
|
|
return std::make_pair(FAILED, std::vector<uint64_t>()); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Get " << page_ids.size() << "pages from index."; |
|
|
|
} |
|
|
|
std::vector<uint64_t> res; |
|
|
|
for (int i = 0; i < static_cast<int>(page_ids.size()); ++i) { |
|
|
|
res.emplace_back(std::stoull(page_ids[i][0])); |
|
|
|
} |
|
|
|
sqlite3_free(errmsg); |
|
|
|
return std::make_pair(SUCCESS, res); |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() { |
|
|
|
std::vector<std::string> blob_fields; |
|
|
|
for (auto &p : GetShardHeader()->GetSchemas()) { |
|
|
|
@@ -642,8 +687,8 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria, |
|
|
|
std::vector<std::vector<std::string>> &labels) { |
|
|
|
MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, |
|
|
|
std::shared_ptr<std::vector<std::vector<std::string>>> labels_ptr) { |
|
|
|
sqlite3_stmt *stmt = nullptr; |
|
|
|
if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { |
|
|
|
MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; |
|
|
|
@@ -661,7 +706,7 @@ MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criter |
|
|
|
for (int i = 0; i < ncols; i++) { |
|
|
|
tmp.emplace_back(reinterpret_cast<const char *>(sqlite3_column_text(stmt, i))); |
|
|
|
} |
|
|
|
labels.push_back(tmp); |
|
|
|
labels_ptr->push_back(tmp); |
|
|
|
rc = sqlite3_step(stmt); |
|
|
|
} |
|
|
|
(void)sqlite3_finalize(stmt); |
|
|
|
@@ -724,16 +769,16 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage( |
|
|
|
auto db = database_paths_[shard_id]; |
|
|
|
std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " + |
|
|
|
std::to_string(page_id); |
|
|
|
std::vector<std::vector<std::string>> label_offsets; |
|
|
|
auto label_offset_ptr = std::make_shared<std::vector<std::vector<std::string>>>(); |
|
|
|
if (!criteria.first.empty()) { |
|
|
|
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria"; |
|
|
|
if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) { |
|
|
|
if (QueryWithCriteria(db, sql, criteria.second, label_offset_ptr) == FAILED) { |
|
|
|
return {FAILED, {}}; |
|
|
|
} |
|
|
|
} else { |
|
|
|
sql += ";"; |
|
|
|
char *errmsg = nullptr; |
|
|
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &label_offsets, &errmsg); |
|
|
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg); |
|
|
|
if (rc != SQLITE_OK) { |
|
|
|
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; |
|
|
|
sqlite3_free(errmsg); |
|
|
|
@@ -741,11 +786,11 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabelsFromPage( |
|
|
|
db = nullptr; |
|
|
|
return {FAILED, {}}; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index."; |
|
|
|
MS_LOG(DEBUG) << "Get " << label_offset_ptr->size() << " records from index."; |
|
|
|
sqlite3_free(errmsg); |
|
|
|
} |
|
|
|
// get labels from binary file |
|
|
|
return GetLabelsFromBinaryFile(shard_id, columns, label_offsets); |
|
|
|
return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int shard_id, |
|
|
|
@@ -760,17 +805,17 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int |
|
|
|
fields += columns[i] + "_" + std::to_string(schema_id); |
|
|
|
} |
|
|
|
if (fields.empty()) fields = "*"; |
|
|
|
std::vector<std::vector<std::string>> labels; |
|
|
|
auto labels_ptr = std::make_shared<std::vector<std::vector<std::string>>>(); |
|
|
|
std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); |
|
|
|
if (!criteria.first.empty()) { |
|
|
|
sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria"; |
|
|
|
if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) { |
|
|
|
if (QueryWithCriteria(db, sql, criteria.second, labels_ptr) == FAILED) { |
|
|
|
return {FAILED, {}}; |
|
|
|
} |
|
|
|
} else { |
|
|
|
sql += ";"; |
|
|
|
char *errmsg = nullptr; |
|
|
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); |
|
|
|
int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels_ptr.get(), &errmsg); |
|
|
|
if (rc != SQLITE_OK) { |
|
|
|
MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; |
|
|
|
sqlite3_free(errmsg); |
|
|
|
@@ -778,13 +823,13 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int |
|
|
|
db = nullptr; |
|
|
|
return {FAILED, {}}; |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels.size()) << "records from index."; |
|
|
|
MS_LOG(DEBUG) << "Get " << static_cast<int>(labels_ptr->size()) << " records from index."; |
|
|
|
} |
|
|
|
sqlite3_free(errmsg); |
|
|
|
} |
|
|
|
std::vector<json> ret; |
|
|
|
for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); |
|
|
|
for (unsigned int i = 0; i < labels.size(); ++i) { |
|
|
|
for (unsigned int i = 0; i < labels_ptr->size(); ++i) ret.emplace_back(json{}); |
|
|
|
for (unsigned int i = 0; i < labels_ptr->size(); ++i) { |
|
|
|
json construct_json; |
|
|
|
for (unsigned int j = 0; j < columns.size(); ++j) { |
|
|
|
// construct json "f1": value |
|
|
|
@@ -792,15 +837,15 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int |
|
|
|
|
|
|
|
// convert the string to base type by schema |
|
|
|
if (schema[columns[j]]["type"] == "int32") { |
|
|
|
construct_json[columns[j]] = StringToNum<int32_t>(labels[i][j]); |
|
|
|
construct_json[columns[j]] = StringToNum<int32_t>((*labels_ptr)[i][j]); |
|
|
|
} else if (schema[columns[j]]["type"] == "int64") { |
|
|
|
construct_json[columns[j]] = StringToNum<int64_t>(labels[i][j]); |
|
|
|
construct_json[columns[j]] = StringToNum<int64_t>((*labels_ptr)[i][j]); |
|
|
|
} else if (schema[columns[j]]["type"] == "float32") { |
|
|
|
construct_json[columns[j]] = StringToNum<float>(labels[i][j]); |
|
|
|
construct_json[columns[j]] = StringToNum<float>((*labels_ptr)[i][j]); |
|
|
|
} else if (schema[columns[j]]["type"] == "float64") { |
|
|
|
construct_json[columns[j]] = StringToNum<double>(labels[i][j]); |
|
|
|
construct_json[columns[j]] = StringToNum<double>((*labels_ptr)[i][j]); |
|
|
|
} else { |
|
|
|
construct_json[columns[j]] = std::string(labels[i][j]); |
|
|
|
construct_json[columns[j]] = std::string((*labels_ptr)[i][j]); |
|
|
|
} |
|
|
|
} |
|
|
|
ret[i] = construct_json; |
|
|
|
@@ -834,7 +879,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { |
|
|
|
} |
|
|
|
std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; |
|
|
|
std::vector<std::thread> threads = std::vector<std::thread>(shard_count); |
|
|
|
std::set<std::string> categories; |
|
|
|
auto category_ptr = std::make_shared<std::set<std::string>>(); |
|
|
|
for (int x = 0; x < shard_count; x++) { |
|
|
|
sqlite3 *db = nullptr; |
|
|
|
int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); |
|
|
|
@@ -843,13 +888,13 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { |
|
|
|
<< sqlite3_errmsg(db); |
|
|
|
return -1; |
|
|
|
} |
|
|
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); |
|
|
|
threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
for (int x = 0; x < shard_count; x++) { |
|
|
|
threads[x].join(); |
|
|
|
} |
|
|
|
return categories.size(); |
|
|
|
return category_ptr->size(); |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths, bool load_dataset, |
|
|
|
@@ -1008,8 +1053,7 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, |
|
|
|
const std::shared_ptr<ShardOperator> &op) { |
|
|
|
MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr<ShardOperator> &op) { |
|
|
|
CheckIfColumnInIndex(selected_columns_); |
|
|
|
auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); |
|
|
|
auto categories = category_op->GetCategories(); |
|
|
|
@@ -1033,42 +1077,50 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i |
|
|
|
MS_LOG(ERROR) << "Invalid parameter, num_categories must be greater than 0, but got " << num_elements; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
std::set<std::string> categories_set; |
|
|
|
auto ret = GetAllClasses(category_field, categories_set); |
|
|
|
auto category_ptr = std::make_shared<std::set<std::string>>(); |
|
|
|
auto ret = GetAllClasses(category_field, category_ptr); |
|
|
|
if (SUCCESS != ret) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
int i = 0; |
|
|
|
for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { |
|
|
|
for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) { |
|
|
|
categories.emplace_back(category_field, *it); |
|
|
|
i++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Generate task list, a task will create a batch |
|
|
|
std::vector<ShardTask> categoryTasks(categories.size()); |
|
|
|
for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { |
|
|
|
int category_index = 0; |
|
|
|
for (const auto &rg : row_group_summary) { |
|
|
|
if (category_index >= num_elements) break; |
|
|
|
auto shard_id = std::get<0>(rg); |
|
|
|
auto group_id = std::get<1>(rg); |
|
|
|
|
|
|
|
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); |
|
|
|
if (SUCCESS != std::get<0>(details)) { |
|
|
|
for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) { |
|
|
|
auto res = GetPagesByCategory(shard_id, categories[categoryNo]); |
|
|
|
if (SUCCESS != res.first) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
auto offsets = std::get<4>(details); |
|
|
|
|
|
|
|
auto number_of_rows = offsets.size(); |
|
|
|
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { |
|
|
|
if (category_index < num_elements) { |
|
|
|
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart], |
|
|
|
std::get<5>(details)[iStart]); |
|
|
|
category_index++; |
|
|
|
auto page_ids = res.second; |
|
|
|
for (const auto &page_id : page_ids) { |
|
|
|
if (category_index >= num_elements) break; |
|
|
|
const auto &page_t = shard_header_->GetPage(shard_id, page_id); |
|
|
|
const auto &page = page_t.first; |
|
|
|
auto group_id = page->GetPageTypeID(); |
|
|
|
auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); |
|
|
|
if (SUCCESS != std::get<0>(details)) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
auto offsets = std::get<4>(details); |
|
|
|
|
|
|
|
auto number_of_rows = offsets.size(); |
|
|
|
for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { |
|
|
|
if (category_index < num_elements) { |
|
|
|
categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, |
|
|
|
std::get<4>(details)[iStart], std::get<5>(details)[iStart]); |
|
|
|
category_index++; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; |
|
|
|
} |
|
|
|
tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples); |
|
|
|
if (SUCCESS != (*category_op)(tasks_)) { |
|
|
|
@@ -1189,7 +1241,7 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { |
|
|
|
if (SUCCESS != CreateTasksByCategory(operators[category_operator])) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|