diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h index eb891c228f..769b8555c5 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -55,6 +55,8 @@ #include "minddata/mindrecord/include/shard_shuffle.h" #include "utils/log_adapter.h" +#define API_PUBLIC __attribute__((visibility("default"))) + namespace mindspore { namespace mindrecord { using ROW_GROUPS = @@ -65,7 +67,7 @@ using TASK_RETURN_CONTENT = std::pair, json>>>>; const int kNumBatchInMap = 1000; // iterator buffer size in row-reader mode -class __attribute__((visibility("default"))) ShardReader { +class API_PUBLIC ShardReader { public: ShardReader(); @@ -203,7 +205,7 @@ class __attribute__((visibility("default"))) ShardReader { void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } /// \brief get all classes - MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); + MSRStatus GetAllClasses(const std::string &category_field, std::shared_ptr> category_ptr); /// \brief get the size of blob data MSRStatus GetTotalBlobSize(int64_t *total_blob_size); @@ -215,11 +217,12 @@ class __attribute__((visibility("default"))) ShardReader { private: /// \brief wrap up labels to json format MSRStatus ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, - std::vector>> &offsets, int shard_id, - const std::vector &columns, std::vector> &column_values); + std::shared_ptr>>> offset_ptr, + int shard_id, const std::vector &columns, + std::shared_ptr>> col_val_ptr); /// \brief read all rows for specified columns - ROW_GROUPS ReadAllRowGroup(std::vector &columns); + ROW_GROUPS ReadAllRowGroup(const std::vector &columns); /// \brief read row meta by shard_id and sample_id ROW_GROUPS ReadRowGroupByShardIDAndSampleID(const std::vector &columns, const uint32_t &shard_id, @@ -227,8 +230,8 @@ class __attribute__((visibility("default"))) ShardReader { /// \brief read all rows in one shard MSRStatus ReadAllRowsInShard(int shard_id, const std::string &sql, const std::vector &columns, - std::vector>> &offsets, - std::vector> &column_values); + std::shared_ptr>>> offset_ptr, + std::shared_ptr>> col_val_ptr); /// \brief initialize reader MSRStatus Init(const std::vector &file_paths, bool load_dataset); @@ -243,8 +246,12 @@ class __attribute__((visibility("default"))) ShardReader { std::vector> GetImageOffset(int group_id, int shard_id, const std::pair &criteria = {"", ""}); + /// \brief get page id by category + std::pair> GetPagesByCategory(int shard_id, + const std::pair &criteria); /// \brief execute sqlite query with prepare statement - MSRStatus QueryWithCriteria(sqlite3 *db, string &sql, string criteria, std::vector> &labels); + MSRStatus QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, + std::shared_ptr>> labels_ptr); /// \brief get column values std::pair> GetLabels(int group_id, int shard_id, const std::vector &columns, @@ -257,8 +264,7 @@ class __attribute__((visibility("default"))) ShardReader { ""}); /// \brief create category-applied task list - MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, - const std::shared_ptr &op); + MSRStatus CreateTasksByCategory(const std::shared_ptr &op); /// \brief create task list in row-reader mode MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, @@ -286,13 +292,15 @@ class __attribute__((visibility("default"))) ShardReader { int shard_id, const std::vector &columns, const std::vector> &label_offsets); /// \brief get classes in one shard - void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); + void GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, + std::shared_ptr> category_ptr); /// \brief get number of classes int64_t GetNumClasses(const std::string &category_field); /// \brief get meta of header - std::pair> GetMeta(const std::string &file_path, json &meta_data); + std::pair> GetMeta(const std::string &file_path, + std::shared_ptr meta_data_ptr); /// \brief extract uncompressed data based on column list std::pair>> UnCompressBlob(const std::vector &raw_blob_data); diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index 98c2bbc5e0..8d771d219d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -51,7 +51,8 @@ ShardReader::ShardReader() lazy_load_(false), shard_sample_count_() {} -std::pair> ShardReader::GetMeta(const std::string &file_path, json &meta_data) { +std::pair> ShardReader::GetMeta(const std::string &file_path, + std::shared_ptr meta_data_ptr) { if (!IsLegalFile(file_path)) { return {FAILED, {}}; } @@ -60,16 +61,16 @@ std::pair> ShardReader::GetMeta(const std::s return {FAILED, {}}; } auto header = ret.second; - meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, - {"version", header["version"]}, {"index_fields", header["index_fields"]}, - {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; + *meta_data_ptr = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, + {"version", header["version"]}, {"index_fields", header["index_fields"]}, + {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; return {SUCCESS, header["shard_addresses"]}; } MSRStatus ShardReader::Init(const std::vector &file_paths, bool load_dataset) { std::string file_path = file_paths[0]; - json first_meta_data = json(); - auto ret = GetMeta(file_path, first_meta_data); + auto first_meta_data_ptr = std::make_shared(); + auto ret = GetMeta(file_path, first_meta_data_ptr); if (ret.first != SUCCESS) { return FAILED; } @@ -91,12 +92,12 @@ MSRStatus ShardReader::Init(const std::vector &file_paths, bool loa return FAILED; } for (const auto &file : file_paths_) { - json meta_data = json(); - auto ret1 = GetMeta(file, meta_data); + auto meta_data_ptr = std::make_shared(); + auto ret1 = GetMeta(file, meta_data_ptr); if (ret1.first != SUCCESS) { return FAILED; } - if (meta_data != first_meta_data) { + if (*meta_data_ptr != *first_meta_data_ptr) { MS_LOG(ERROR) << "Mindrecord files meta information is different."; return FAILED; } @@ -140,7 +141,7 @@ MSRStatus ShardReader::Init(const std::vector &file_paths, bool loa header_size_ = shard_header_->GetHeaderSize(); page_size_ = shard_header_->GetPageSize(); // version < 3.0 - if (first_meta_data["version"] < kVersion) { + if ((*first_meta_data_ptr)["version"] < kVersion) { shard_column_ = std::make_shared(shard_header_, false); } else { shard_column_ = std::make_shared(shard_header_, true); @@ -314,14 +315,14 @@ MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) { MSRStatus ShardReader::ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, - std::vector>> &offsets, int shard_id, - const std::vector &columns, - std::vector> &column_values) { + std::shared_ptr>>> offset_ptr, + int shard_id, const std::vector &columns, + std::shared_ptr>> col_val_ptr) { for (int i = 0; i < static_cast(labels.size()); ++i) { uint64_t group_id = std::stoull(labels[i][0]); uint64_t offset_start = std::stoull(labels[i][1]) + kInt64Len; uint64_t offset_end = std::stoull(labels[i][2]); - offsets[shard_id].emplace_back( + (*offset_ptr)[shard_id].emplace_back( std::vector{static_cast(shard_id), group_id, offset_start, offset_end}); if (!all_in_index_) { int raw_page_id = std::stoi(labels[i][3]); @@ -353,7 +354,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector &columns, - std::vector>> &offsets, - std::vector> &column_values) { + std::shared_ptr>>> offset_ptr, + std::shared_ptr>> col_val_ptr) { auto db = database_paths_[shard_id]; std::vector> labels; char *errmsg = nullptr; @@ -406,10 +407,11 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, } } sqlite3_free(errmsg); - return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); + return ConvertLabelToJson(labels, fs, offset_ptr, shard_id, columns, col_val_ptr); } -MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { +MSRStatus ShardReader::GetAllClasses(const std::string &category_field, + std::shared_ptr> category_ptr) { std::map index_columns; for (auto &field : GetShardHeader()->GetFields()) { index_columns[field.second] = field.first; @@ -425,7 +427,7 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; std::vector threads = std::vector(shard_count_); for (int x = 0; x < shard_count_; x++) { - threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, category_ptr); } for (int x = 0; x < shard_count_; x++) { @@ -434,8 +436,8 @@ MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set return SUCCESS; } -void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, - std::set &categories) { +void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string &sql, + std::shared_ptr> category_ptr) { if (nullptr == db) { return; } @@ -452,20 +454,22 @@ void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string MS_LOG(INFO) << "Get " << static_cast(columns.size()) << " records from shard " << shard_id << " index."; std::lock_guard lck(shard_locker_); for (int i = 0; i < static_cast(columns.size()); ++i) { - categories.emplace(columns[i][0]); + category_ptr->emplace(columns[i][0]); } } -ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { +ROW_GROUPS ShardReader::ReadAllRowGroup(const std::vector &columns) { std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; - std::vector>> offsets(shard_count_, std::vector>{}); - std::vector> column_values(shard_count_, std::vector{}); + auto offset_ptr = std::make_shared>>>( + shard_count_, std::vector>{}); + auto col_val_ptr = std::make_shared>>(shard_count_, std::vector{}); + if (all_in_index_) { for (unsigned int i = 0; i < columns.size(); ++i) { fields += ','; auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); if (ret.first != SUCCESS) { - return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); + return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr)); } fields += ret.second; } @@ -477,27 +481,27 @@ ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { std::vector thread_read_db = std::vector(shard_count_); for (int x = 0; x < shard_count_; x++) { - thread_read_db[x] = - std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, std::ref(offsets), std::ref(column_values)); + thread_read_db[x] = std::thread(&ShardReader::ReadAllRowsInShard, this, x, sql, columns, offset_ptr, col_val_ptr); } for (int x = 0; x < shard_count_; x++) { thread_read_db[x].join(); } - return std::make_tuple(SUCCESS, std::move(offsets), std::move(column_values)); + return std::make_tuple(SUCCESS, std::move(*offset_ptr), std::move(*col_val_ptr)); } ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector &columns, const uint32_t &shard_id, const uint32_t &sample_id) { std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; - std::vector>> offsets(shard_count_, std::vector>{}); - std::vector> column_values(shard_count_, std::vector{}); + auto offset_ptr = std::make_shared>>>( + shard_count_, std::vector>{}); + auto col_val_ptr = std::make_shared>>(shard_count_, std::vector{}); if (all_in_index_) { for (unsigned int i = 0; i < columns.size(); ++i) { fields += ','; auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[columns[i]], columns[i])); if (ret.first != SUCCESS) { - return std::make_tuple(FAILED, std::move(offsets), std::move(column_values)); + return std::make_tuple(FAILED, std::move(*offset_ptr), std::move(*col_val_ptr)); } fields += ret.second; } @@ -507,12 +511,12 @@ ROW_GROUPS ShardReader::ReadRowGroupByShardIDAndSampleID(const std::vector &columns) { @@ -550,7 +554,10 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, uint64_t page_length = page->GetPageSize(); uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); - + if (image_offset.empty()) { + return std::make_tuple(SUCCESS, file_name, page_length, page_offset, std::vector>(), + std::vector()); + } auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); if (status_labels.first != SUCCESS) { return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); @@ -601,7 +608,7 @@ std::vector> ShardReader::GetImageOffset(int page_id, int db = nullptr; return std::vector>(); } else { - MS_LOG(DEBUG) << "Get " << static_cast(image_offsets.size()) << "records from index."; + MS_LOG(DEBUG) << "Get " << static_cast(image_offsets.size()) << " records from index."; } std::vector> res; for (int i = static_cast(image_offsets.size()) - 1; i >= 0; i--) res.emplace_back(std::vector{0, 0}); @@ -614,6 +621,44 @@ std::vector> ShardReader::GetImageOffset(int page_id, int return res; } +std::pair> ShardReader::GetPagesByCategory( + int shard_id, const std::pair &criteria) { + auto db = database_paths_[shard_id]; + + std::string sql = "SELECT DISTINCT PAGE_ID_BLOB FROM INDEXES WHERE 1 = 1 "; + + if (!criteria.first.empty()) { + auto schema = shard_header_->GetSchemas()[0]->GetSchema(); + + if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { + sql += + " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + criteria.second; + } else { + sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = '" + + criteria.second + "'"; + } + } + sql += ";"; + std::vector> page_ids; + char *errmsg = nullptr; + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &page_ids, &errmsg); + if (rc != SQLITE_OK) { + MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; + sqlite3_free(errmsg); + sqlite3_close(db); + db = nullptr; + return std::make_pair(FAILED, std::vector()); + } else { + MS_LOG(DEBUG) << "Get " << page_ids.size() << "pages from index."; + } + std::vector res; + for (int i = 0; i < static_cast(page_ids.size()); ++i) { + res.emplace_back(std::stoull(page_ids[i][0])); + } + sqlite3_free(errmsg); + return std::make_pair(SUCCESS, res); +} + std::pair> ShardReader::GetBlobFields() { std::vector blob_fields; for (auto &p : GetShardHeader()->GetSchemas()) { @@ -642,8 +687,8 @@ void ShardReader::CheckIfColumnInIndex(const std::vector &columns) } } -MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criteria, - std::vector> &labels) { +MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, const string &sql, const string &criteria, + std::shared_ptr>> labels_ptr) { sqlite3_stmt *stmt = nullptr; if (sqlite3_prepare_v2(db, common::SafeCStr(sql), -1, &stmt, 0) != SQLITE_OK) { MS_LOG(ERROR) << "SQL error: could not prepare statement, sql: " << sql; @@ -661,7 +706,7 @@ MSRStatus ShardReader::QueryWithCriteria(sqlite3 *db, string &sql, string criter for (int i = 0; i < ncols; i++) { tmp.emplace_back(reinterpret_cast(sqlite3_column_text(stmt, i))); } - labels.push_back(tmp); + labels_ptr->push_back(tmp); rc = sqlite3_step(stmt); } (void)sqlite3_finalize(stmt); @@ -724,16 +769,16 @@ std::pair> ShardReader::GetLabelsFromPage( auto db = database_paths_[shard_id]; std::string sql = "SELECT PAGE_ID_RAW, PAGE_OFFSET_RAW,PAGE_OFFSET_RAW_END FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); - std::vector> label_offsets; + auto label_offset_ptr = std::make_shared>>(); if (!criteria.first.empty()) { sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = :criteria"; - if (QueryWithCriteria(db, sql, criteria.second, label_offsets) == FAILED) { + if (QueryWithCriteria(db, sql, criteria.second, label_offset_ptr) == FAILED) { return {FAILED, {}}; } } else { sql += ";"; char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &label_offsets, &errmsg); + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, label_offset_ptr.get(), &errmsg); if (rc != SQLITE_OK) { MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; sqlite3_free(errmsg); @@ -741,11 +786,11 @@ std::pair> ShardReader::GetLabelsFromPage( db = nullptr; return {FAILED, {}}; } - MS_LOG(DEBUG) << "Get " << label_offsets.size() << "records from index."; + MS_LOG(DEBUG) << "Get " << label_offset_ptr->size() << " records from index."; sqlite3_free(errmsg); } // get labels from binary file - return GetLabelsFromBinaryFile(shard_id, columns, label_offsets); + return GetLabelsFromBinaryFile(shard_id, columns, *label_offset_ptr); } std::pair> ShardReader::GetLabels(int page_id, int shard_id, @@ -760,17 +805,17 @@ std::pair> ShardReader::GetLabels(int page_id, int fields += columns[i] + "_" + std::to_string(schema_id); } if (fields.empty()) fields = "*"; - std::vector> labels; + auto labels_ptr = std::make_shared>>(); std::string sql = "SELECT " + fields + " FROM INDEXES WHERE PAGE_ID_BLOB = " + std::to_string(page_id); if (!criteria.first.empty()) { sql += " AND " + criteria.first + "_" + std::to_string(column_schema_id_[criteria.first]) + " = " + ":criteria"; - if (QueryWithCriteria(db, sql, criteria.second, labels) == FAILED) { + if (QueryWithCriteria(db, sql, criteria.second, labels_ptr) == FAILED) { return {FAILED, {}}; } } else { sql += ";"; char *errmsg = nullptr; - int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &labels, &errmsg); + int rc = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, labels_ptr.get(), &errmsg); if (rc != SQLITE_OK) { MS_LOG(ERROR) << "Error in select statement, sql: " << sql << ", error: " << errmsg; sqlite3_free(errmsg); @@ -778,13 +823,13 @@ std::pair> ShardReader::GetLabels(int page_id, int db = nullptr; return {FAILED, {}}; } else { - MS_LOG(DEBUG) << "Get " << static_cast(labels.size()) << "records from index."; + MS_LOG(DEBUG) << "Get " << static_cast(labels_ptr->size()) << " records from index."; } sqlite3_free(errmsg); } std::vector ret; - for (unsigned int i = 0; i < labels.size(); ++i) ret.emplace_back(json{}); - for (unsigned int i = 0; i < labels.size(); ++i) { + for (unsigned int i = 0; i < labels_ptr->size(); ++i) ret.emplace_back(json{}); + for (unsigned int i = 0; i < labels_ptr->size(); ++i) { json construct_json; for (unsigned int j = 0; j < columns.size(); ++j) { // construct json "f1": value @@ -792,15 +837,15 @@ std::pair> ShardReader::GetLabels(int page_id, int // convert the string to base type by schema if (schema[columns[j]]["type"] == "int32") { - construct_json[columns[j]] = StringToNum(labels[i][j]); + construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); } else if (schema[columns[j]]["type"] == "int64") { - construct_json[columns[j]] = StringToNum(labels[i][j]); + construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); } else if (schema[columns[j]]["type"] == "float32") { - construct_json[columns[j]] = StringToNum(labels[i][j]); + construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); } else if (schema[columns[j]]["type"] == "float64") { - construct_json[columns[j]] = StringToNum(labels[i][j]); + construct_json[columns[j]] = StringToNum((*labels_ptr)[i][j]); } else { - construct_json[columns[j]] = std::string(labels[i][j]); + construct_json[columns[j]] = std::string((*labels_ptr)[i][j]); } } ret[i] = construct_json; @@ -834,7 +879,7 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { } std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; std::vector threads = std::vector(shard_count); - std::set categories; + auto category_ptr = std::make_shared>(); for (int x = 0; x < shard_count; x++) { sqlite3 *db = nullptr; int rc = sqlite3_open_v2(common::SafeCStr(file_paths_[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); @@ -843,13 +888,13 @@ int64_t ShardReader::GetNumClasses(const std::string &category_field) { << sqlite3_errmsg(db); return -1; } - threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, category_ptr); } for (int x = 0; x < shard_count; x++) { threads[x].join(); } - return categories.size(); + return category_ptr->size(); } MSRStatus ShardReader::CountTotalRows(const std::vector &file_paths, bool load_dataset, @@ -1008,8 +1053,7 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { return SUCCESS; } -MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, - const std::shared_ptr &op) { +MSRStatus ShardReader::CreateTasksByCategory(const std::shared_ptr &op) { CheckIfColumnInIndex(selected_columns_); auto category_op = std::dynamic_pointer_cast(op); auto categories = category_op->GetCategories(); @@ -1033,42 +1077,50 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector categories_set; - auto ret = GetAllClasses(category_field, categories_set); + auto category_ptr = std::make_shared>(); + auto ret = GetAllClasses(category_field, category_ptr); if (SUCCESS != ret) { return FAILED; } int i = 0; - for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { + for (auto it = category_ptr->begin(); it != category_ptr->end() && i < num_categories; ++it) { categories.emplace_back(category_field, *it); i++; } } + // Generate task list, a task will create a batch std::vector categoryTasks(categories.size()); for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { int category_index = 0; - for (const auto &rg : row_group_summary) { - if (category_index >= num_elements) break; - auto shard_id = std::get<0>(rg); - auto group_id = std::get<1>(rg); - - auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); - if (SUCCESS != std::get<0>(details)) { + for (int shard_id = 0; shard_id < shard_count_ && category_index < num_elements; ++shard_id) { + auto res = GetPagesByCategory(shard_id, categories[categoryNo]); + if (SUCCESS != res.first) { return FAILED; } - auto offsets = std::get<4>(details); - - auto number_of_rows = offsets.size(); - for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { - if (category_index < num_elements) { - categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, std::get<4>(details)[iStart], - std::get<5>(details)[iStart]); - category_index++; + auto page_ids = res.second; + for (const auto &page_id : page_ids) { + if (category_index >= num_elements) break; + const auto &page_t = shard_header_->GetPage(shard_id, page_id); + const auto &page = page_t.first; + auto group_id = page->GetPageTypeID(); + auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], selected_columns_); + if (SUCCESS != std::get<0>(details)) { + return FAILED; + } + auto offsets = std::get<4>(details); + + auto number_of_rows = offsets.size(); + for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { + if (category_index < num_elements) { + categoryTasks[categoryNo].InsertTask(TaskType::kCommonTask, shard_id, group_id, + std::get<4>(details)[iStart], std::get<5>(details)[iStart]); + category_index++; + } } + MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; } } - MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; } tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples); if (SUCCESS != (*category_op)(tasks_)) { @@ -1189,7 +1241,7 @@ MSRStatus ShardReader::CreateTasks(const std::vector