Merge pull request !998 from guozhijian/enhance_format_func_nametags/v0.3.0-alpha
| @@ -108,7 +108,7 @@ Status MindRecordOp::Init() { | |||
| data_schema_ = std::make_unique<DataSchema>(); | |||
| std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->get_shard_header()->get_schemas(); | |||
| std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas(); | |||
| // check whether schema exists, if so use the first one | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found"); | |||
| mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"]; | |||
| @@ -155,7 +155,7 @@ Status MindRecordOp::Init() { | |||
| column_name_mapping_[columns_to_load_[i]] = i; | |||
| } | |||
| num_rows_ = shard_reader_->get_num_rows(); | |||
| num_rows_ = shard_reader_->GetNumRows(); | |||
| // Compute how many buffers we would need to accomplish rowsPerBuffer | |||
| buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; | |||
| RETURN_IF_NOT_OK(SetColumnsBlob()); | |||
| @@ -164,7 +164,7 @@ Status MindRecordOp::Init() { | |||
| } | |||
| Status MindRecordOp::SetColumnsBlob() { | |||
| columns_blob_ = shard_reader_->get_blob_fields().second; | |||
| columns_blob_ = shard_reader_->GetBlobFields().second; | |||
| // get the exactly blob fields by columns_to_load_ | |||
| std::vector<std::string> columns_blob_exact; | |||
| @@ -600,7 +600,7 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { | |||
| // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work | |||
| Status MindRecordOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); | |||
| num_rows_ = shard_reader_->get_num_rows(); | |||
| num_rows_ = shard_reader_->GetNumRows(); | |||
| buffers_needed_ = num_rows_ / rows_per_buffer_; | |||
| if (num_rows_ % rows_per_buffer_ != 0) { | |||
| @@ -39,18 +39,18 @@ namespace mindrecord { | |||
| void BindSchema(py::module *m) { | |||
| (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local()) | |||
| .def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build) | |||
| .def("get_desc", &Schema::get_desc) | |||
| .def("get_desc", &Schema::GetDesc) | |||
| .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) | |||
| .def("get_blob_fields", &Schema::get_blob_fields) | |||
| .def("get_schema_id", &Schema::get_schema_id); | |||
| .def("get_blob_fields", &Schema::GetBlobFields) | |||
| .def("get_schema_id", &Schema::GetSchemaID); | |||
| } | |||
| void BindStatistics(const py::module *m) { | |||
| (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local()) | |||
| .def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build) | |||
| .def("get_desc", &Statistics::get_desc) | |||
| .def("get_desc", &Statistics::GetDesc) | |||
| .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) | |||
| .def("get_statistics_id", &Statistics::get_statistics_id); | |||
| .def("get_statistics_id", &Statistics::GetStatisticsID); | |||
| } | |||
| void BindShardHeader(const py::module *m) { | |||
| @@ -60,9 +60,9 @@ void BindShardHeader(const py::module *m) { | |||
| .def("add_statistics", &ShardHeader::AddStatistic) | |||
| .def("add_index_fields", | |||
| (MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields) | |||
| .def("get_meta", &ShardHeader::get_schemas) | |||
| .def("get_statistics", &ShardHeader::get_statistics) | |||
| .def("get_fields", &ShardHeader::get_fields) | |||
| .def("get_meta", &ShardHeader::GetSchemas) | |||
| .def("get_statistics", &ShardHeader::GetStatistics) | |||
| .def("get_fields", &ShardHeader::GetFields) | |||
| .def("get_schema_by_id", &ShardHeader::GetSchemaByID) | |||
| .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); | |||
| } | |||
| @@ -72,8 +72,8 @@ void BindShardWriter(py::module *m) { | |||
| .def(py::init<>()) | |||
| .def("open", &ShardWriter::Open) | |||
| .def("open_for_append", &ShardWriter::OpenForAppend) | |||
| .def("set_header_size", &ShardWriter::set_header_size) | |||
| .def("set_page_size", &ShardWriter::set_page_size) | |||
| .def("set_header_size", &ShardWriter::SetHeaderSize) | |||
| .def("set_page_size", &ShardWriter::SetPageSize) | |||
| .def("set_shard_header", &ShardWriter::SetShardHeader) | |||
| .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, | |||
| vector<vector<uint8_t>> &, bool, bool)) & | |||
| @@ -88,8 +88,8 @@ void BindShardReader(const py::module *m) { | |||
| const std::vector<std::shared_ptr<ShardOperator>> &)) & | |||
| ShardReader::OpenPy) | |||
| .def("launch", &ShardReader::Launch) | |||
| .def("get_header", &ShardReader::get_shard_header) | |||
| .def("get_blob_fields", &ShardReader::get_blob_fields) | |||
| .def("get_header", &ShardReader::GetShardHeader) | |||
| .def("get_blob_fields", &ShardReader::GetBlobFields) | |||
| .def("get_next", | |||
| (std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) | |||
| .def("finish", &ShardReader::Finish) | |||
| @@ -119,9 +119,9 @@ void BindShardSegment(py::module *m) { | |||
| .def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>( | |||
| ShardSegment::*)(std::string, int64_t, int64_t)) & | |||
| ShardSegment::ReadAtPageByNamePy) | |||
| .def("get_header", &ShardSegment::get_shard_header) | |||
| .def("get_header", &ShardSegment::GetShardHeader) | |||
| .def("get_blob_fields", | |||
| (std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::get_blob_fields); | |||
| (std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetBlobFields); | |||
| } | |||
| void BindGlobalParams(py::module *m) { | |||
| @@ -36,7 +36,7 @@ class ShardCategory : public ShardOperator { | |||
| ~ShardCategory() override{}; | |||
| const std::vector<std::pair<std::string, std::string>> &get_categories() const { return categories_; } | |||
| const std::vector<std::pair<std::string, std::string>> &GetCategories() const { return categories_; } | |||
| const std::string GetCategoryField() const { return category_field_; } | |||
| @@ -46,7 +46,7 @@ class ShardCategory : public ShardOperator { | |||
| bool GetReplacement() const { return replacement_; } | |||
| MSRStatus execute(ShardTask &tasks) override; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -58,19 +58,19 @@ class ShardHeader { | |||
| /// \brief get the schema | |||
| /// \return the schema | |||
| std::vector<std::shared_ptr<Schema>> get_schemas(); | |||
| std::vector<std::shared_ptr<Schema>> GetSchemas(); | |||
| /// \brief get Statistics | |||
| /// \return the Statistic | |||
| std::vector<std::shared_ptr<Statistics>> get_statistics(); | |||
| std::vector<std::shared_ptr<Statistics>> GetStatistics(); | |||
| /// \brief get the fields of the index | |||
| /// \return the fields of the index | |||
| std::vector<std::pair<uint64_t, std::string>> get_fields(); | |||
| std::vector<std::pair<uint64_t, std::string>> GetFields(); | |||
| /// \brief get the index | |||
| /// \return the index | |||
| std::shared_ptr<Index> get_index(); | |||
| std::shared_ptr<Index> GetIndex(); | |||
| /// \brief get the schema by schemaid | |||
| /// \param[in] schemaId the id of schema needs to be got | |||
| @@ -80,7 +80,7 @@ class ShardHeader { | |||
| /// \brief get the filepath to shard by shardID | |||
| /// \param[in] shardID the id of shard which filepath needs to be obtained | |||
| /// \return the filepath obtained by shardID | |||
| std::string get_shard_address_by_id(int64_t shard_id); | |||
| std::string GetShardAddressByID(int64_t shard_id); | |||
| /// \brief get the statistic by statistic id | |||
| /// \param[in] statisticId the id of statistic needs to be get | |||
| @@ -89,7 +89,7 @@ class ShardHeader { | |||
| MSRStatus InitByFiles(const std::vector<std::string> &file_paths); | |||
| void set_index(Index index) { index_ = std::make_shared<Index>(index); } | |||
| void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } | |||
| std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id); | |||
| @@ -103,21 +103,21 @@ class ShardHeader { | |||
| const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id); | |||
| std::vector<std::string> get_shard_addresses() const { return shard_addresses_; } | |||
| std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } | |||
| int get_shard_count() const { return shard_count_; } | |||
| int GetShardCount() const { return shard_count_; } | |||
| int get_schema_count() const { return schema_.size(); } | |||
| int GetSchemaCount() const { return schema_.size(); } | |||
| uint64_t get_header_size() const { return header_size_; } | |||
| uint64_t GetHeaderSize() const { return header_size_; } | |||
| uint64_t get_page_size() const { return page_size_; } | |||
| uint64_t GetPageSize() const { return page_size_; } | |||
| void set_header_size(const uint64_t &header_size) { header_size_ = header_size; } | |||
| void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } | |||
| void set_page_size(const uint64_t &page_size) { page_size_ = page_size; } | |||
| void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | |||
| const string get_version() { return version_; } | |||
| const string GetVersion() { return version_; } | |||
| std::vector<std::string> SerializeHeader(); | |||
| @@ -132,7 +132,7 @@ class ShardHeader { | |||
| /// \param[in] the shard data real path | |||
| /// \param[in] the headers which readed from the shard data | |||
| /// \return SUCCESS/FAILED | |||
| MSRStatus get_headers(const vector<string> &real_addresses, std::vector<json> &headers); | |||
| MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); | |||
| MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); | |||
| @@ -52,7 +52,7 @@ class Index { | |||
| /// \brief get stored fields | |||
| /// \return fields stored | |||
| std::vector<std::pair<uint64_t, std::string> > get_fields(); | |||
| std::vector<std::pair<uint64_t, std::string> > GetFields(); | |||
| private: | |||
| std::vector<std::pair<uint64_t, std::string> > fields_; | |||
| @@ -26,23 +26,23 @@ class ShardOperator { | |||
| virtual ~ShardOperator() = default; | |||
| MSRStatus operator()(ShardTask &tasks) { | |||
| if (SUCCESS != this->pre_execute(tasks)) { | |||
| if (SUCCESS != this->PreExecute(tasks)) { | |||
| return FAILED; | |||
| } | |||
| if (SUCCESS != this->execute(tasks)) { | |||
| if (SUCCESS != this->Execute(tasks)) { | |||
| return FAILED; | |||
| } | |||
| if (SUCCESS != this->suf_execute(tasks)) { | |||
| if (SUCCESS != this->SufExecute(tasks)) { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; } | |||
| virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } | |||
| virtual MSRStatus execute(ShardTask &tasks) = 0; | |||
| virtual MSRStatus Execute(ShardTask &tasks) = 0; | |||
| virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } | |||
| virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } | |||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } | |||
| }; | |||
| @@ -53,29 +53,29 @@ class Page { | |||
| /// \return the json format of the page and its description | |||
| json GetPage() const; | |||
| int get_page_id() const { return page_id_; } | |||
| int GetPageID() const { return page_id_; } | |||
| int get_shard_id() const { return shard_id_; } | |||
| int GetShardID() const { return shard_id_; } | |||
| int get_page_type_id() const { return page_type_id_; } | |||
| int GetPageTypeID() const { return page_type_id_; } | |||
| std::string get_page_type() const { return page_type_; } | |||
| std::string GetPageType() const { return page_type_; } | |||
| uint64_t get_page_size() const { return page_size_; } | |||
| uint64_t GetPageSize() const { return page_size_; } | |||
| uint64_t get_start_row_id() const { return start_row_id_; } | |||
| uint64_t GetStartRowID() const { return start_row_id_; } | |||
| uint64_t get_end_row_id() const { return end_row_id_; } | |||
| uint64_t GetEndRowID() const { return end_row_id_; } | |||
| void set_end_row_id(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } | |||
| void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } | |||
| void set_page_size(const uint64_t &page_size) { page_size_ = page_size; } | |||
| void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | |||
| std::pair<int, uint64_t> get_last_row_group_id() const { return row_group_ids_.back(); } | |||
| std::pair<int, uint64_t> GetLastRowGroupID() const { return row_group_ids_.back(); } | |||
| std::vector<std::pair<int, uint64_t>> get_row_group_ids() const { return row_group_ids_; } | |||
| std::vector<std::pair<int, uint64_t>> GetRowGroupIds() const { return row_group_ids_; } | |||
| void set_row_group_ids(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) { | |||
| void SetRowGroupIds(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) { | |||
| row_group_ids_ = last_row_group_ids; | |||
| } | |||
| @@ -37,7 +37,7 @@ class ShardPkSample : public ShardCategory { | |||
| ~ShardPkSample() override{}; | |||
| MSRStatus suf_execute(ShardTask &tasks) override; | |||
| MSRStatus SufExecute(ShardTask &tasks) override; | |||
| private: | |||
| bool shuffle_; | |||
| @@ -107,11 +107,11 @@ class ShardReader { | |||
| /// \brief aim to get the meta data | |||
| /// \return the metadata | |||
| std::shared_ptr<ShardHeader> get_shard_header() const; | |||
| std::shared_ptr<ShardHeader> GetShardHeader() const; | |||
| /// \brief get the number of shards | |||
| /// \return # of shards | |||
| int get_shard_count() const; | |||
| int GetShardCount() const; | |||
| /// \brief get the number of rows in database | |||
| /// \param[in] file_path the path of ONE file, any file in dataset is fine | |||
| @@ -126,7 +126,7 @@ class ShardReader { | |||
| /// \brief get the number of rows in database | |||
| /// \return # of rows | |||
| int get_num_rows() const; | |||
| int GetNumRows() const; | |||
| /// \brief Read the summary of row groups | |||
| /// \return the tuple of 4 elements | |||
| @@ -185,7 +185,7 @@ class ShardReader { | |||
| /// \brief get blob filed list | |||
| /// \return blob field list | |||
| std::pair<ShardType, std::vector<std::string>> get_blob_fields(); | |||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | |||
| /// \brief reset reader | |||
| /// \return null | |||
| @@ -193,10 +193,10 @@ class ShardReader { | |||
| /// \brief set flag of all-in-index | |||
| /// \return null | |||
| void set_all_in_index(bool all_in_index) { all_in_index_ = all_in_index; } | |||
| void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } | |||
| /// \brief get NLP flag | |||
| bool get_nlp_flag(); | |||
| bool GetNlpFlag(); | |||
| /// \brief get all classes | |||
| MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); | |||
| @@ -38,11 +38,11 @@ class ShardSample : public ShardOperator { | |||
| ~ShardSample() override{}; | |||
| const std::pair<int, int> get_partitions() const; | |||
| const std::pair<int, int> GetPartitions() const; | |||
| MSRStatus execute(ShardTask &tasks) override; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| MSRStatus suf_execute(ShardTask &tasks) override; | |||
| MSRStatus SufExecute(ShardTask &tasks) override; | |||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||
| @@ -51,7 +51,7 @@ class Schema { | |||
| /// \brief get the schema and its description | |||
| /// \return the json format of the schema and its description | |||
| std::string get_desc() const; | |||
| std::string GetDesc() const; | |||
| /// \brief get the schema and its description | |||
| /// \return the json format of the schema and its description | |||
| @@ -63,15 +63,15 @@ class Schema { | |||
| /// set the schema id | |||
| /// \param[in] id the id need to be set | |||
| void set_schema_id(int64_t id); | |||
| void SetSchemaID(int64_t id); | |||
| /// get the schema id | |||
| /// \return the int64 schema id | |||
| int64_t get_schema_id() const; | |||
| int64_t GetSchemaID() const; | |||
| /// get the blob fields | |||
| /// \return the vector<string> blob fields | |||
| std::vector<std::string> get_blob_fields() const; | |||
| std::vector<std::string> GetBlobFields() const; | |||
| private: | |||
| Schema() = default; | |||
| @@ -81,7 +81,7 @@ class ShardSegment : public ShardReader { | |||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy( | |||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page); | |||
| std::pair<ShardType, std::vector<std::string>> get_blob_fields(); | |||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | |||
| private: | |||
| std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo(); | |||
| @@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator { | |||
| ~ShardShuffle() override{}; | |||
| MSRStatus execute(ShardTask &tasks) override; | |||
| MSRStatus Execute(ShardTask &tasks) override; | |||
| private: | |||
| uint32_t shuffle_seed_; | |||
| @@ -53,11 +53,11 @@ class Statistics { | |||
| /// \brief get the description | |||
| /// \return the description | |||
| std::string get_desc() const; | |||
| std::string GetDesc() const; | |||
| /// \brief get the statistic | |||
| /// \return json format of the statistic | |||
| json get_statistics() const; | |||
| json GetStatistics() const; | |||
| /// \brief get the statistic for python | |||
| /// \return the python object of statistics | |||
| @@ -66,11 +66,11 @@ class Statistics { | |||
| /// \brief decode the bson statistics to json | |||
| /// \param[in] encodedStatistics the bson type of statistics | |||
| /// \return json type of statistic | |||
| void set_statistics_id(int64_t id); | |||
| void SetStatisticsID(int64_t id); | |||
| /// \brief get the statistics id | |||
| /// \return the int64 statistics id | |||
| int64_t get_statistics_id() const; | |||
| int64_t GetStatisticsID() const; | |||
| private: | |||
| /// \brief validate the statistic | |||
| @@ -39,9 +39,9 @@ class ShardTask { | |||
| uint32_t SizeOfRows() const; | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id); | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id); | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_random_task(); | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask(); | |||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements); | |||
| @@ -69,12 +69,12 @@ class ShardWriter { | |||
| /// \brief Set file size | |||
| /// \param[in] header_size the size of header, only (1<<N) is accepted | |||
| /// \return MSRStatus the status of MSRStatus | |||
| MSRStatus set_header_size(const uint64_t &header_size); | |||
| MSRStatus SetHeaderSize(const uint64_t &header_size); | |||
| /// \brief Set page size | |||
| /// \param[in] page_size the size of page, only (1<<N) is accepted | |||
| /// \return MSRStatus the status of MSRStatus | |||
| MSRStatus set_page_size(const uint64_t &page_size); | |||
| MSRStatus SetPageSize(const uint64_t &page_size); | |||
| /// \brief Set shard header | |||
| /// \param[in] header_data the info of header | |||
| @@ -64,7 +64,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str | |||
| } | |||
| // schema does not contain the field | |||
| auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"]; | |||
| auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; | |||
| if (schema.find(field) == schema.end()) { | |||
| MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; | |||
| return {FAILED, ""}; | |||
| @@ -203,7 +203,7 @@ MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::stri | |||
| } | |||
| std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) { | |||
| std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); | |||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | |||
| if (shard_address.empty()) { | |||
| MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; | |||
| return {FAILED, nullptr}; | |||
| @@ -357,12 +357,12 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( | |||
| MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | |||
| const std::shared_ptr<Page> cur_blob_page, | |||
| uint64_t &cur_blob_page_offset, std::fstream &in) { | |||
| row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->get_page_id())); | |||
| row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); | |||
| // blob data start | |||
| row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); | |||
| auto &io_seekg_blob = | |||
| in.seekg(page_size_ * cur_blob_page->get_page_id() + header_size_ + cur_blob_page_offset, std::ios::beg); | |||
| in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); | |||
| if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { | |||
| MS_LOG(ERROR) << "File seekg failed"; | |||
| in.close(); | |||
| @@ -405,7 +405,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||
| std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; | |||
| // related blob page | |||
| vector<pair<int, uint64_t>> row_group_list = cur_raw_page->get_row_group_ids(); | |||
| vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds(); | |||
| // pair: row_group id, offset in raw data page | |||
| for (pair<int, int> blob_ids : row_group_list) { | |||
| @@ -415,18 +415,18 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||
| // offset in current raw data page | |||
| auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second); | |||
| uint64_t cur_blob_page_offset = 0; | |||
| for (unsigned int i = cur_blob_page->get_start_row_id(); i < cur_blob_page->get_end_row_id(); ++i) { | |||
| for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { | |||
| std::vector<std::tuple<std::string, std::string, std::string>> row_data; | |||
| row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); | |||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->get_page_type_id())); | |||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->get_page_id())); | |||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); | |||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); | |||
| // raw data start | |||
| row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); | |||
| // calculate raw data end | |||
| auto &io_seekg = | |||
| in.seekg(page_size_ * (cur_raw_page->get_page_id()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||
| in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | |||
| MS_LOG(ERROR) << "File seekg failed"; | |||
| in.close(); | |||
| @@ -473,7 +473,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||
| INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) { | |||
| std::vector<std::tuple<std::string, std::string, std::string>> fields; | |||
| // index fields | |||
| std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.get_fields(); | |||
| std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields(); | |||
| for (const auto &field : index_fields) { | |||
| if (field.first >= schema_detail.size()) { | |||
| return {FAILED, {}}; | |||
| @@ -504,7 +504,7 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std | |||
| const std::vector<int> &raw_page_ids, | |||
| const std::map<int, int> &blob_id_to_page_id) { | |||
| // Add index data to database | |||
| std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); | |||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | |||
| if (shard_address.empty()) { | |||
| MS_LOG(ERROR) << "Shard address is null"; | |||
| return FAILED; | |||
| @@ -546,12 +546,12 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std | |||
| } | |||
| MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||
| fields_ = shard_header_.get_fields(); | |||
| page_size_ = shard_header_.get_page_size(); | |||
| header_size_ = shard_header_.get_header_size(); | |||
| schema_count_ = shard_header_.get_schema_count(); | |||
| if (shard_header_.get_shard_count() > kMaxShardCount) { | |||
| MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount; | |||
| fields_ = shard_header_.GetFields(); | |||
| page_size_ = shard_header_.GetPageSize(); | |||
| header_size_ = shard_header_.GetHeaderSize(); | |||
| schema_count_ = shard_header_.GetSchemaCount(); | |||
| if (shard_header_.GetShardCount() > kMaxShardCount) { | |||
| MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; | |||
| return FAILED; | |||
| } | |||
| task_ = 0; // set two atomic vars to initial value | |||
| @@ -559,7 +559,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||
| // spawn half the physical threads or total number of shards whichever is smaller | |||
| const unsigned int num_workers = | |||
| std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count())); | |||
| std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount())); | |||
| std::vector<std::thread> threads; | |||
| threads.reserve(num_workers); | |||
| @@ -576,7 +576,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||
| void ShardIndexGenerator::DatabaseWriter() { | |||
| int shard_no = task_++; | |||
| while (shard_no < shard_header_.get_shard_count()) { | |||
| while (shard_no < shard_header_.GetShardCount()) { | |||
| auto db = CreateDatabase(shard_no); | |||
| if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { | |||
| write_success_ = false; | |||
| @@ -592,10 +592,10 @@ void ShardIndexGenerator::DatabaseWriter() { | |||
| std::vector<int> raw_page_ids; | |||
| for (uint64_t i = 0; i < total_pages; ++i) { | |||
| std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first; | |||
| if (cur_page->get_page_type() == "RAW_DATA") { | |||
| if (cur_page->GetPageType() == "RAW_DATA") { | |||
| raw_page_ids.push_back(i); | |||
| } else if (cur_page->get_page_type() == "BLOB_DATA") { | |||
| blob_id_to_page_id[cur_page->get_page_type_id()] = i; | |||
| } else if (cur_page->GetPageType() == "BLOB_DATA") { | |||
| blob_id_to_page_id[cur_page->GetPageTypeID()] = i; | |||
| } | |||
| } | |||
| @@ -56,9 +56,9 @@ MSRStatus ShardReader::Init(const std::string &file_path) { | |||
| return FAILED; | |||
| } | |||
| shard_header_ = std::make_shared<ShardHeader>(sh); | |||
| header_size_ = shard_header_->get_header_size(); | |||
| page_size_ = shard_header_->get_page_size(); | |||
| file_paths_ = shard_header_->get_shard_addresses(); | |||
| header_size_ = shard_header_->GetHeaderSize(); | |||
| page_size_ = shard_header_->GetPageSize(); | |||
| file_paths_ = shard_header_->GetShardAddresses(); | |||
| for (const auto &file : file_paths_) { | |||
| sqlite3 *db = nullptr; | |||
| @@ -105,7 +105,7 @@ MSRStatus ShardReader::Init(const std::string &file_path) { | |||
| MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) { | |||
| vector<int> inSchema(selected_columns.size(), 0); | |||
| for (auto &p : get_shard_header()->get_schemas()) { | |||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||
| auto schema = p->GetSchema()["schema"]; | |||
| for (unsigned int i = 0; i < selected_columns.size(); ++i) { | |||
| if (schema.find(selected_columns[i]) != schema.end()) { | |||
| @@ -183,15 +183,15 @@ void ShardReader::Close() { | |||
| FileStreamsOperator(); | |||
| } | |||
| std::shared_ptr<ShardHeader> ShardReader::get_shard_header() const { return shard_header_; } | |||
| std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; } | |||
| int ShardReader::get_shard_count() const { return shard_header_->get_shard_count(); } | |||
| int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } | |||
| int ShardReader::get_num_rows() const { return num_rows_; } | |||
| int ShardReader::GetNumRows() const { return num_rows_; } | |||
| std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() { | |||
| std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary; | |||
| int shard_count = shard_header_->get_shard_count(); | |||
| int shard_count = shard_header_->GetShardCount(); | |||
| if (shard_count <= 0) { | |||
| return row_group_summary; | |||
| } | |||
| @@ -205,13 +205,13 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar | |||
| for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { | |||
| const auto &page_t = shard_header_->GetPage(shard_id, page_id); | |||
| const auto &page = page_t.first; | |||
| if (page->get_page_type() != kPageTypeBlob) continue; | |||
| uint64_t start_row_id = page->get_start_row_id(); | |||
| if (start_row_id > page->get_end_row_id()) { | |||
| if (page->GetPageType() != kPageTypeBlob) continue; | |||
| uint64_t start_row_id = page->GetStartRowID(); | |||
| if (start_row_id > page->GetEndRowID()) { | |||
| return std::vector<std::tuple<int, int, int, uint64_t>>(); | |||
| } | |||
| uint64_t number_of_rows = page->get_end_row_id() - start_row_id; | |||
| row_group_summary.emplace_back(shard_id, page->get_page_type_id(), start_row_id, number_of_rows); | |||
| uint64_t number_of_rows = page->GetEndRowID() - start_row_id; | |||
| row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); | |||
| } | |||
| } | |||
| } | |||
| @@ -265,7 +265,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str | |||
| json construct_json; | |||
| for (unsigned int j = 0; j < columns.size(); ++j) { | |||
| // construct json "f1": value | |||
| auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"]; | |||
| auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; | |||
| // convert the string to base type by schema | |||
| if (schema[columns[j]]["type"] == "int32") { | |||
| @@ -317,7 +317,7 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, | |||
| MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { | |||
| std::map<std::string, uint64_t> index_columns; | |||
| for (auto &field : get_shard_header()->get_fields()) { | |||
| for (auto &field : GetShardHeader()->GetFields()) { | |||
| index_columns[field.second] = field.first; | |||
| } | |||
| if (index_columns.find(category_field) == index_columns.end()) { | |||
| @@ -400,11 +400,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const | |||
| } | |||
| const std::shared_ptr<Page> &page = ret.second; | |||
| std::string file_name = file_paths_[shard_id]; | |||
| uint64_t page_length = page->get_page_size(); | |||
| uint64_t page_offset = page_size_ * page->get_page_id() + header_size_; | |||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id); | |||
| uint64_t page_length = page->GetPageSize(); | |||
| uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; | |||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id); | |||
| auto status_labels = GetLabels(page->get_page_id(), shard_id, columns); | |||
| auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); | |||
| if (status_labels.first != SUCCESS) { | |||
| return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>()); | |||
| } | |||
| @@ -426,11 +426,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, | |||
| } | |||
| const std::shared_ptr<Page> &page = ret.second; | |||
| std::string file_name = file_paths_[shard_id]; | |||
| uint64_t page_length = page->get_page_size(); | |||
| uint64_t page_offset = page_size_ * page->get_page_id() + header_size_; | |||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id, criteria); | |||
| uint64_t page_length = page->GetPageSize(); | |||
| uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; | |||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); | |||
| auto status_labels = GetLabels(page->get_page_id(), shard_id, columns, criteria); | |||
| auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); | |||
| if (status_labels.first != SUCCESS) { | |||
| return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>()); | |||
| } | |||
| @@ -458,7 +458,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int | |||
| // whether use index search | |||
| if (!criteria.first.empty()) { | |||
| auto schema = shard_header_->get_schemas()[0]->GetSchema(); | |||
| auto schema = shard_header_->GetSchemas()[0]->GetSchema(); | |||
| // not number field should add '' in sql | |||
| if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { | |||
| @@ -497,13 +497,13 @@ void ShardReader::CheckNlp() { | |||
| return; | |||
| } | |||
| bool ShardReader::get_nlp_flag() { return nlp_; } | |||
| bool ShardReader::GetNlpFlag() { return nlp_; } | |||
| std::pair<ShardType, std::vector<std::string>> ShardReader::get_blob_fields() { | |||
| std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() { | |||
| std::vector<std::string> blob_fields; | |||
| for (auto &p : get_shard_header()->get_schemas()) { | |||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||
| // assume one schema | |||
| const auto &fields = p->get_blob_fields(); | |||
| const auto &fields = p->GetBlobFields(); | |||
| blob_fields.assign(fields.begin(), fields.end()); | |||
| break; | |||
| } | |||
| @@ -516,7 +516,7 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) | |||
| all_in_index_ = false; | |||
| return; | |||
| } | |||
| for (auto &field : get_shard_header()->get_fields()) { | |||
| for (auto &field : GetShardHeader()->GetFields()) { | |||
| column_schema_id_[field.second] = field.first; | |||
| } | |||
| for (auto &col : columns) { | |||
| @@ -671,7 +671,7 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int | |||
| json construct_json; | |||
| for (unsigned int j = 0; j < columns.size(); ++j) { | |||
| // construct json "f1": value | |||
| auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"]; | |||
| auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; | |||
| // convert the string to base type by schema | |||
| if (schema[columns[j]]["type"] == "int32") { | |||
| @@ -719,9 +719,9 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri | |||
| return -1; | |||
| } | |||
| auto header = std::make_shared<ShardHeader>(sh); | |||
| auto file_paths = header->get_shard_addresses(); | |||
| auto file_paths = header->GetShardAddresses(); | |||
| auto shard_count = file_paths.size(); | |||
| auto index_fields = header->get_fields(); | |||
| auto index_fields = header->GetFields(); | |||
| std::map<std::string, int64_t> map_schema_id_fields; | |||
| for (auto &field : index_fields) { | |||
| @@ -799,7 +799,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, | |||
| if (nlp_) { | |||
| selected_columns_ = selected_columns; | |||
| } else { | |||
| vector<std::string> blob_fields = get_blob_fields().second; | |||
| vector<std::string> blob_fields = GetBlobFields().second; | |||
| for (unsigned int i = 0; i < selected_columns.size(); ++i) { | |||
| if (!std::any_of(blob_fields.begin(), blob_fields.end(), | |||
| [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { | |||
| @@ -846,7 +846,7 @@ MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consume | |||
| } | |||
| // should remove blob field from selected_columns when call from python | |||
| std::vector<std::string> columns(selected_columns); | |||
| auto blob_fields = get_blob_fields().second; | |||
| auto blob_fields = GetBlobFields().second; | |||
| for (auto &blob_field : blob_fields) { | |||
| auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); | |||
| if (it != selected_columns.end()) { | |||
| @@ -909,7 +909,7 @@ vector<std::string> ShardReader::GetAllColumns() { | |||
| vector<std::string> columns; | |||
| if (nlp_) { | |||
| for (auto &c : selected_columns_) { | |||
| for (auto &p : get_shard_header()->get_schemas()) { | |||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||
| auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. | |||
| for (auto it = schema.begin(); it != schema.end(); ++it) { | |||
| if (it.key() == c) { | |||
| @@ -943,7 +943,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i | |||
| CheckIfColumnInIndex(columns); | |||
| auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | |||
| auto categories = category_op->get_categories(); | |||
| auto categories = category_op->GetCategories(); | |||
| int64_t num_elements = category_op->GetNumElements(); | |||
| if (num_elements <= 0) { | |||
| MS_LOG(ERROR) << "Parameter num_element is not positive"; | |||
| @@ -1104,7 +1104,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| } | |||
| // Pick up task from task list | |||
| auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]); | |||
| auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); | |||
| auto shard_id = std::get<0>(std::get<0>(task)); | |||
| auto group_id = std::get<1>(std::get<0>(task)); | |||
| @@ -1117,7 +1117,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| // Pack image list | |||
| std::vector<uint8_t> images(addr[1] - addr[0]); | |||
| auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0]; | |||
| auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0]; | |||
| auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); | |||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | |||
| @@ -1139,7 +1139,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||
| if (selected_columns_.size() == 0) { | |||
| images_with_exact_columns = images; | |||
| } else { | |||
| auto blob_fields = get_blob_fields(); | |||
| auto blob_fields = GetBlobFields(); | |||
| std::vector<uint32_t> ordered_selected_columns_index; | |||
| uint32_t index = 0; | |||
| @@ -1272,7 +1272,7 @@ MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { | |||
| } | |||
| // Pick up task from task list | |||
| auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]); | |||
| auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); | |||
| auto shard_id = std::get<0>(std::get<0>(task)); | |||
| auto group_id = std::get<1>(std::get<0>(task)); | |||
| @@ -28,7 +28,7 @@ using mindspore::MsLogLevel::INFO; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardSegment::ShardSegment() { set_all_in_index(false); } | |||
| ShardSegment::ShardSegment() { SetAllInIndex(false); } | |||
| std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | |||
| // Skip if already populated | |||
| @@ -211,7 +211,7 @@ std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id | |||
| // Pack image list | |||
| std::vector<uint8_t> images(offset[1] - offset[0]); | |||
| auto file_offset = header_size_ + page_size_ * (blob_page->get_page_id()) + offset[0]; | |||
| auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; | |||
| auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); | |||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | |||
| MS_LOG(ERROR) << "File seekg failed"; | |||
| @@ -363,21 +363,21 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::obje | |||
| return {SUCCESS, std::move(json_data)}; | |||
| } | |||
| std::pair<ShardType, std::vector<std::string>> ShardSegment::get_blob_fields() { | |||
| std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { | |||
| std::vector<std::string> blob_fields; | |||
| for (auto &p : get_shard_header()->get_schemas()) { | |||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||
| // assume one schema | |||
| const auto &fields = p->get_blob_fields(); | |||
| const auto &fields = p->GetBlobFields(); | |||
| blob_fields.assign(fields.begin(), fields.end()); | |||
| break; | |||
| } | |||
| return std::make_pair(get_nlp_flag() ? kNLP : kCV, blob_fields); | |||
| return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields); | |||
| } | |||
| std::tuple<std::vector<uint8_t>, json> ShardSegment::GetImageLabel(std::vector<uint8_t> images, json label) { | |||
| if (get_nlp_flag()) { | |||
| if (GetNlpFlag()) { | |||
| vector<std::string> columns; | |||
| for (auto &p : get_shard_header()->get_schemas()) { | |||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||
| auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. | |||
| auto schema_items = schema.items(); | |||
| using it_type = decltype(schema_items.begin()); | |||
| @@ -179,12 +179,12 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { | |||
| return FAILED; | |||
| } | |||
| shard_header_ = std::make_shared<ShardHeader>(sh); | |||
| auto paths = shard_header_->get_shard_addresses(); | |||
| MSRStatus ret = set_header_size(shard_header_->get_header_size()); | |||
| auto paths = shard_header_->GetShardAddresses(); | |||
| MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); | |||
| if (ret == FAILED) { | |||
| return FAILED; | |||
| } | |||
| ret = set_page_size(shard_header_->get_page_size()); | |||
| ret = SetPageSize(shard_header_->GetPageSize()); | |||
| if (ret == FAILED) { | |||
| return FAILED; | |||
| } | |||
| @@ -229,10 +229,10 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) | |||
| } | |||
| // set fields in mindrecord when empty | |||
| std::vector<std::pair<uint64_t, std::string>> fields = header_data->get_fields(); | |||
| std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields(); | |||
| if (fields.empty()) { | |||
| MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields."; | |||
| std::vector<std::shared_ptr<Schema>> schemas = header_data->get_schemas(); | |||
| std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas(); | |||
| for (const auto &schema : schemas) { | |||
| json jsonSchema = schema->GetSchema()["schema"]; | |||
| for (const auto &el : jsonSchema.items()) { | |||
| @@ -241,7 +241,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) | |||
| (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) || | |||
| (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) || | |||
| (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) { | |||
| fields.emplace_back(std::make_pair(schema->get_schema_id(), el.key())); | |||
| fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key())); | |||
| } | |||
| } | |||
| } | |||
| @@ -256,12 +256,12 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) | |||
| } | |||
| shard_header_ = header_data; | |||
| shard_header_->set_header_size(header_size_); | |||
| shard_header_->set_page_size(page_size_); | |||
| shard_header_->SetHeaderSize(header_size_); | |||
| shard_header_->SetPageSize(page_size_); | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) { | |||
| MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) { | |||
| // header_size [16KB, 128MB] | |||
| if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { | |||
| MS_LOG(ERROR) << "Header size should between 16KB and 128MB."; | |||
| @@ -276,7 +276,7 @@ MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) { | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardWriter::set_page_size(const uint64_t &page_size) { | |||
| MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) { | |||
| // PageSize [32KB, 256MB] | |||
| if (page_size < kMinPageSize || page_size > kMaxPageSize) { | |||
| MS_LOG(ERROR) << "Page size should between 16KB and 256MB."; | |||
| @@ -398,7 +398,7 @@ MSRStatus ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &ra | |||
| return FAILED; | |||
| } | |||
| json schema = result.first->GetSchema()["schema"]; | |||
| for (const auto &field : result.first->get_blob_fields()) { | |||
| for (const auto &field : result.first->GetBlobFields()) { | |||
| (void)schema.erase(field); | |||
| } | |||
| std::vector<json> sub_raw_data = rawdata_iter->second; | |||
| @@ -456,7 +456,7 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t, | |||
| MS_LOG(DEBUG) << "Schema count is " << schema_count_; | |||
| // Determine if the number of schemas is the same | |||
| if (shard_header_->get_schemas().size() != schema_count_) { | |||
| if (shard_header_->GetSchemas().size() != schema_count_) { | |||
| MS_LOG(ERROR) << "Data size is not equal with the schema size"; | |||
| return failed; | |||
| } | |||
| @@ -475,9 +475,9 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t, | |||
| } | |||
| (void)schema_ids.insert(rawdata_iter->first); | |||
| } | |||
| const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->get_schemas(); | |||
| const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas(); | |||
| if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr<Schema> &schema) { | |||
| return schema_ids.find(schema->get_schema_id()) == schema_ids.end(); | |||
| return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); | |||
| })) { | |||
| // There is not enough data which is not matching the number of schema | |||
| MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; | |||
| @@ -810,10 +810,10 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector | |||
| std::vector<std::pair<int, int>> &rows_in_group, | |||
| const std::shared_ptr<Page> &last_raw_page, | |||
| const std::shared_ptr<Page> &last_blob_page) { | |||
| auto n_byte_blob = last_blob_page ? last_blob_page->get_page_size() : 0; | |||
| auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0; | |||
| auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0; | |||
| auto last_raw_offset = last_raw_page ? last_raw_page->get_last_row_group_id().second : 0; | |||
| auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; | |||
| auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0; | |||
| auto n_byte_raw = last_raw_page_size - last_raw_offset; | |||
| int page_start_row = start_row; | |||
| @@ -849,8 +849,8 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std | |||
| if (blob_row.first == blob_row.second) return SUCCESS; | |||
| // Write disk | |||
| auto page_id = last_blob_page->get_page_id(); | |||
| auto bytes_page = last_blob_page->get_page_size(); | |||
| auto page_id = last_blob_page->GetPageID(); | |||
| auto bytes_page = last_blob_page->GetPageSize(); | |||
| auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); | |||
| if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { | |||
| MS_LOG(ERROR) << "File seekp failed"; | |||
| @@ -862,9 +862,9 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std | |||
| // Update last blob page | |||
| bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); | |||
| last_blob_page->set_page_size(bytes_page); | |||
| uint64_t end_row = last_blob_page->get_end_row_id() + blob_row.second - blob_row.first; | |||
| last_blob_page->set_end_row_id(end_row); | |||
| last_blob_page->SetPageSize(bytes_page); | |||
| uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first; | |||
| last_blob_page->SetEndRowID(end_row); | |||
| (void)shard_header_->SetPage(last_blob_page); | |||
| return SUCCESS; | |||
| } | |||
| @@ -873,8 +873,8 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::v | |||
| const std::vector<std::pair<int, int>> &rows_in_group, | |||
| const std::shared_ptr<Page> &last_blob_page) { | |||
| auto page_id = shard_header_->GetLastPageId(shard_id); | |||
| auto page_type_id = last_blob_page ? last_blob_page->get_page_type_id() : -1; | |||
| auto current_row = last_blob_page ? last_blob_page->get_end_row_id() : 0; | |||
| auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; | |||
| auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0; | |||
| // index(0) indicate appendBlobPage | |||
| for (uint32_t i = 1; i < rows_in_group.size(); ++i) { | |||
| auto blob_row = rows_in_group[i]; | |||
| @@ -905,15 +905,15 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std:: | |||
| std::shared_ptr<Page> &last_raw_page) { | |||
| auto blob_row = rows_in_group[0]; | |||
| if (blob_row.first == blob_row.second) return SUCCESS; | |||
| auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0; | |||
| auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; | |||
| if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + | |||
| last_raw_page_size <= | |||
| page_size_) { | |||
| return SUCCESS; | |||
| } | |||
| auto page_id = shard_header_->GetLastPageId(shard_id); | |||
| auto last_row_group_id_offset = last_raw_page->get_last_row_group_id().second; | |||
| auto last_raw_page_id = last_raw_page->get_page_id(); | |||
| auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; | |||
| auto last_raw_page_id = last_raw_page->GetPageID(); | |||
| auto shift_size = last_raw_page_size - last_row_group_id_offset; | |||
| std::vector<uint8_t> buf(shift_size); | |||
| @@ -956,10 +956,10 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std:: | |||
| (void)shard_header_->SetPage(last_raw_page); | |||
| // Refresh page info in header | |||
| int row_group_id = last_raw_page->get_last_row_group_id().first + 1; | |||
| int row_group_id = last_raw_page->GetLastRowGroupID().first + 1; | |||
| std::vector<std::pair<int, uint64_t>> row_group_ids; | |||
| row_group_ids.emplace_back(row_group_id, 0); | |||
| int page_type_id = last_raw_page->get_page_id(); | |||
| int page_type_id = last_raw_page->GetPageID(); | |||
| auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); | |||
| (void)shard_header_->AddPage(std::make_shared<Page>(page)); | |||
| @@ -971,7 +971,7 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std:: | |||
| MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| std::shared_ptr<Page> &last_raw_page, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data) { | |||
| int last_row_group_id = last_raw_page ? last_raw_page->get_last_row_group_id().first : -1; | |||
| int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; | |||
| for (uint32_t i = 0; i < rows_in_group.size(); ++i) { | |||
| const auto &blob_row = rows_in_group[i]; | |||
| if (blob_row.first == blob_row.second) continue; | |||
| @@ -979,7 +979,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std:: | |||
| std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); | |||
| if (!last_raw_page) { | |||
| EmptyRawPage(shard_id, last_raw_page); | |||
| } else if (last_raw_page->get_page_size() + raw_size > page_size_) { | |||
| } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { | |||
| (void)shard_header_->SetPage(last_raw_page); | |||
| EmptyRawPage(shard_id, last_raw_page); | |||
| } | |||
| @@ -994,7 +994,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std:: | |||
| void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) { | |||
| auto row_group_ids = std::vector<std::pair<int, uint64_t>>(); | |||
| auto page_id = shard_header_->GetLastPageId(shard_id); | |||
| auto page_type_id = last_raw_page ? last_raw_page->get_page_id() : -1; | |||
| auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; | |||
| auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); | |||
| (void)shard_header_->AddPage(std::make_shared<Page>(page)); | |||
| SetLastRawPage(shard_id, last_raw_page); | |||
| @@ -1003,9 +1003,9 @@ void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_ | |||
| MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | |||
| const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page, | |||
| const std::vector<std::vector<uint8_t>> &bin_raw_data) { | |||
| std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->get_row_group_ids(); | |||
| auto last_raw_page_id = last_raw_page->get_page_id(); | |||
| auto n_bytes = last_raw_page->get_page_size(); | |||
| std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds(); | |||
| auto last_raw_page_id = last_raw_page->GetPageID(); | |||
| auto n_bytes = last_raw_page->GetPageSize(); | |||
| // previous raw data page | |||
| auto &io_seekp = | |||
| @@ -1022,8 +1022,8 @@ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std: | |||
| (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); | |||
| // Update previous raw data page | |||
| last_raw_page->set_page_size(n_bytes); | |||
| last_raw_page->set_row_group_ids(row_group_ids); | |||
| last_raw_page->SetPageSize(n_bytes); | |||
| last_raw_page->SetRowGroupIds(row_group_ids); | |||
| (void)shard_header_->SetPage(last_raw_page); | |||
| return SUCCESS; | |||
| @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem | |||
| num_categories_(num_categories), | |||
| replacement_(replacement) {} | |||
| MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } | |||
| MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } | |||
| int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| if (dataset_size == 0) return dataset_size; | |||
| @@ -343,7 +343,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() { | |||
| std::string ShardHeader::SerializeIndexFields() { | |||
| json j; | |||
| auto fields = index_->get_fields(); | |||
| auto fields = index_->GetFields(); | |||
| for (const auto &field : fields) { | |||
| j.push_back({{"schema_id", field.first}, {"index_field", field.second}}); | |||
| } | |||
| @@ -365,7 +365,7 @@ std::vector<std::string> ShardHeader::SerializePage() { | |||
| std::string ShardHeader::SerializeStatistics() { | |||
| json j; | |||
| for (const auto &stats : statistics_) { | |||
| j.emplace_back(stats->get_statistics()); | |||
| j.emplace_back(stats->GetStatistics()); | |||
| } | |||
| return j.dump(); | |||
| } | |||
| @@ -398,8 +398,8 @@ MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) { | |||
| if (new_page == nullptr) { | |||
| return FAILED; | |||
| } | |||
| int shard_id = new_page->get_shard_id(); | |||
| int page_id = new_page->get_page_id(); | |||
| int shard_id = new_page->GetShardID(); | |||
| int page_id = new_page->GetPageID(); | |||
| if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | |||
| pages_[shard_id][page_id] = new_page; | |||
| return SUCCESS; | |||
| @@ -412,8 +412,8 @@ MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) { | |||
| if (new_page == nullptr) { | |||
| return FAILED; | |||
| } | |||
| int shard_id = new_page->get_shard_id(); | |||
| int page_id = new_page->get_page_id(); | |||
| int shard_id = new_page->GetShardID(); | |||
| int page_id = new_page->GetPageID(); | |||
| if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) { | |||
| pages_[shard_id].push_back(new_page); | |||
| return SUCCESS; | |||
| @@ -435,8 +435,8 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag | |||
| } | |||
| int last_page_id = -1; | |||
| for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | |||
| if (pages_[shard_id][i - 1]->get_page_type() == page_type) { | |||
| last_page_id = pages_[shard_id][i - 1]->get_page_id(); | |||
| if (pages_[shard_id][i - 1]->GetPageType() == page_type) { | |||
| last_page_id = pages_[shard_id][i - 1]->GetPageID(); | |||
| return last_page_id; | |||
| } | |||
| } | |||
| @@ -451,7 +451,7 @@ const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId( | |||
| } | |||
| for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | |||
| auto page = pages_[shard_id][i - 1]; | |||
| if (page->get_page_type() == kPageTypeBlob && page->get_page_type_id() == group_id) { | |||
| if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { | |||
| return {SUCCESS, page}; | |||
| } | |||
| } | |||
| @@ -470,10 +470,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { | |||
| return -1; | |||
| } | |||
| int64_t schema_id = schema->get_schema_id(); | |||
| int64_t schema_id = schema->GetSchemaID(); | |||
| if (schema_id == -1) { | |||
| schema_id = schema_.size(); | |||
| schema->set_schema_id(schema_id); | |||
| schema->SetSchemaID(schema_id); | |||
| } | |||
| schema_.push_back(schema); | |||
| return schema_id; | |||
| @@ -481,10 +481,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { | |||
| void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) { | |||
| if (statistic) { | |||
| int64_t statistics_id = statistic->get_statistics_id(); | |||
| int64_t statistics_id = statistic->GetStatisticsID(); | |||
| if (statistics_id == -1) { | |||
| statistics_id = statistics_.size(); | |||
| statistic->set_statistics_id(statistics_id); | |||
| statistic->SetStatisticsID(statistics_id); | |||
| } | |||
| statistics_.push_back(statistic); | |||
| } | |||
| @@ -527,13 +527,13 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||
| return FAILED; | |||
| } | |||
| if (get_schemas().empty()) { | |||
| if (GetSchemas().empty()) { | |||
| MS_LOG(ERROR) << "No schema is set"; | |||
| return FAILED; | |||
| } | |||
| for (const auto &schemaPtr : schema_) { | |||
| auto result = GetSchemaByID(schemaPtr->get_schema_id()); | |||
| auto result = GetSchemaByID(schemaPtr->GetSchemaID()); | |||
| if (result.second != SUCCESS) { | |||
| MS_LOG(ERROR) << "Could not get schema by id."; | |||
| return FAILED; | |||
| @@ -548,7 +548,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||
| // checkout and add fields for each schema | |||
| std::set<std::string> field_set; | |||
| for (const auto &item : index->get_fields()) { | |||
| for (const auto &item : index->GetFields()) { | |||
| field_set.insert(item.second); | |||
| } | |||
| for (const auto &field : fields) { | |||
| @@ -564,7 +564,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||
| field_set.insert(field); | |||
| // add field into index | |||
| index.get()->AddIndexField(schemaPtr->get_schema_id(), field); | |||
| index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); | |||
| } | |||
| } | |||
| @@ -575,12 +575,12 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||
| MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) { | |||
| // get all schema id | |||
| for (const auto &schema : schema_) { | |||
| auto bucket_it = bucket_count.find(schema->get_schema_id()); | |||
| auto bucket_it = bucket_count.find(schema->GetSchemaID()); | |||
| if (bucket_it != bucket_count.end()) { | |||
| MS_LOG(ERROR) << "Schema duplication"; | |||
| return FAILED; | |||
| } else { | |||
| bucket_count.insert(schema->get_schema_id()); | |||
| bucket_count.insert(schema->GetSchemaID()); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| @@ -603,7 +603,7 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin | |||
| // check and add fields for each schema | |||
| std::set<std::pair<uint64_t, std::string>> field_set; | |||
| for (const auto &item : index->get_fields()) { | |||
| for (const auto &item : index->GetFields()) { | |||
| field_set.insert(item); | |||
| } | |||
| for (const auto &field : fields) { | |||
| @@ -646,20 +646,20 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin | |||
| return SUCCESS; | |||
| } | |||
| std::string ShardHeader::get_shard_address_by_id(int64_t shard_id) { | |||
| std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { | |||
| if (shard_id >= shard_addresses_.size()) { | |||
| return ""; | |||
| } | |||
| return shard_addresses_.at(shard_id); | |||
| } | |||
| std::vector<std::shared_ptr<Schema>> ShardHeader::get_schemas() { return schema_; } | |||
| std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; } | |||
| std::vector<std::shared_ptr<Statistics>> ShardHeader::get_statistics() { return statistics_; } | |||
| std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; } | |||
| std::vector<std::pair<uint64_t, std::string>> ShardHeader::get_fields() { return index_->get_fields(); } | |||
| std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); } | |||
| std::shared_ptr<Index> ShardHeader::get_index() { return index_; } | |||
| std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; } | |||
| std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { | |||
| int64_t schemaSize = schema_.size(); | |||
| @@ -28,6 +28,6 @@ void Index::AddIndexField(const int64_t &schemaId, const std::string &field) { | |||
| } | |||
| // Get attribute list | |||
| std::vector<std::pair<uint64_t, std::string>> Index::get_fields() { return fields_; } | |||
| std::vector<std::pair<uint64_t, std::string>> Index::GetFields() { return fields_; } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -34,7 +34,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem | |||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | |||
| } | |||
| MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) { | |||
| MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { | |||
| if (shuffle_ == true) { | |||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||
| return FAILED; | |||
| @@ -74,14 +74,14 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| return -1; | |||
| } | |||
| const std::pair<int, int> ShardSample::get_partitions() const { | |||
| const std::pair<int, int> ShardSample::GetPartitions() const { | |||
| if (numerator_ == 1 && denominator_ > 1) { | |||
| return std::pair<int, int>(denominator_, partition_id_); | |||
| } | |||
| return std::pair<int, int>(-1, -1); | |||
| } | |||
| MSRStatus ShardSample::execute(ShardTask &tasks) { | |||
| MSRStatus ShardSample::Execute(ShardTask &tasks) { | |||
| int no_of_categories = static_cast<int>(tasks.categories); | |||
| int total_no = static_cast<int>(tasks.Size()); | |||
| @@ -114,11 +114,11 @@ MSRStatus ShardSample::execute(ShardTask &tasks) { | |||
| if (sampler_type_ == kSubsetRandomSampler) { | |||
| for (int i = 0; i < indices_.size(); ++i) { | |||
| int index = ((indices_[i] % total_no) + total_no) % total_no; | |||
| new_tasks.InsertTask(tasks.get_task_by_id(index)); // different mod result between c and python | |||
| new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python | |||
| } | |||
| } else { | |||
| for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | |||
| new_tasks.InsertTask(tasks.get_task_by_id(i % total_no)); // rounding up. if overflow, go back to start | |||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start | |||
| } | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| @@ -129,14 +129,14 @@ MSRStatus ShardSample::execute(ShardTask &tasks) { | |||
| } | |||
| total_no = static_cast<int>(tasks.permutation_.size()); | |||
| for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | |||
| new_tasks.InsertTask(tasks.get_task_by_id(tasks.permutation_[i % total_no])); | |||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| MSRStatus ShardSample::suf_execute(ShardTask &tasks) { | |||
| MSRStatus ShardSample::SufExecute(ShardTask &tasks) { | |||
| if (sampler_type_ == kSubsetRandomSampler) { | |||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||
| return FAILED; | |||
| @@ -44,7 +44,7 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema) | |||
| return Build(std::move(desc), schema_json); | |||
| } | |||
| std::string Schema::get_desc() const { return desc_; } | |||
| std::string Schema::GetDesc() const { return desc_; } | |||
| json Schema::GetSchema() const { | |||
| json str_schema; | |||
| @@ -60,11 +60,11 @@ pybind11::object Schema::GetSchemaForPython() const { | |||
| return schema_py; | |||
| } | |||
| void Schema::set_schema_id(int64_t id) { schema_id_ = id; } | |||
| void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } | |||
| int64_t Schema::get_schema_id() const { return schema_id_; } | |||
| int64_t Schema::GetSchemaID() const { return schema_id_; } | |||
| std::vector<std::string> Schema::get_blob_fields() const { return blob_fields_; } | |||
| std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; } | |||
| std::vector<std::string> Schema::PopulateBlobFields(json schema) { | |||
| std::vector<std::string> blob_fields; | |||
| @@ -155,7 +155,7 @@ bool Schema::Validate(json schema) { | |||
| } | |||
| bool Schema::operator==(const mindrecord::Schema &b) const { | |||
| if (this->get_desc() != b.get_desc() || this->GetSchema() != b.GetSchema()) { | |||
| if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -23,7 +23,7 @@ namespace mindrecord { | |||
| ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) | |||
| : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} | |||
| MSRStatus ShardShuffle::execute(ShardTask &tasks) { | |||
| MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||
| if (tasks.categories < 1) { | |||
| return FAILED; | |||
| } | |||
| @@ -48,9 +48,9 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle | |||
| return std::make_shared<Statistics>(object_statistics); | |||
| } | |||
| std::string Statistics::get_desc() const { return desc_; } | |||
| std::string Statistics::GetDesc() const { return desc_; } | |||
| json Statistics::get_statistics() const { | |||
| json Statistics::GetStatistics() const { | |||
| json str_statistics; | |||
| str_statistics["desc"] = desc_; | |||
| str_statistics["statistics"] = statistics_; | |||
| @@ -58,13 +58,13 @@ json Statistics::get_statistics() const { | |||
| } | |||
| pybind11::object Statistics::GetStatisticsForPython() const { | |||
| json str_statistics = Statistics::get_statistics(); | |||
| json str_statistics = Statistics::GetStatistics(); | |||
| return nlohmann::detail::FromJsonImpl(str_statistics); | |||
| } | |||
| void Statistics::set_statistics_id(int64_t id) { statistics_id_ = id; } | |||
| void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } | |||
| int64_t Statistics::get_statistics_id() const { return statistics_id_; } | |||
| int64_t Statistics::GetStatisticsID() const { return statistics_id_; } | |||
| bool Statistics::Validate(const json &statistics) { | |||
| if (statistics.size() != kInt1) { | |||
| @@ -103,7 +103,7 @@ bool Statistics::LevelRecursive(json level) { | |||
| } | |||
| bool Statistics::operator==(const Statistics &b) const { | |||
| if (this->get_statistics() != b.get_statistics()) { | |||
| if (this->GetStatistics() != b.GetStatistics()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -59,12 +59,12 @@ uint32_t ShardTask::SizeOfRows() const { | |||
| return nRows; | |||
| } | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_task_by_id(size_t id) { | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) { | |||
| MS_ASSERT(id < task_list_.size()); | |||
| return task_list_[id]; | |||
| } | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_random_task() { | |||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() { | |||
| std::random_device rd; | |||
| std::mt19937 gen(rd()); | |||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | |||
| @@ -82,7 +82,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||
| } | |||
| for (uint32_t task_no = 0; task_no < minTasks; task_no++) { | |||
| for (uint32_t i = 0; i < total_categories; i++) { | |||
| res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no)))); | |||
| res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no)))); | |||
| } | |||
| } | |||
| } else { | |||
| @@ -95,7 +95,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||
| } | |||
| for (uint32_t i = 0; i < total_categories; i++) { | |||
| for (uint32_t j = 0; j < maxTasks; j++) { | |||
| res.InsertTask(category_tasks[i].get_random_task()); | |||
| res.InsertTask(category_tasks[i].GetRandomTask()); | |||
| } | |||
| } | |||
| } | |||
| @@ -52,7 +52,7 @@ TEST_F(TestShard, TestShardSchemaPart) { | |||
| std::shared_ptr<Schema> schema = Schema::Build(desc, j); | |||
| ASSERT_TRUE(schema != nullptr); | |||
| MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " << | |||
| MS_LOG(INFO) << "schema description: " << schema->GetDesc() << ", schema: " << | |||
| common::SafeCStr(schema->GetSchema().dump()); | |||
| for (int i = 1; i <= 4; i++) { | |||
| string filename = std::string("./imagenet.shard0") + std::to_string(i); | |||
| @@ -71,8 +71,8 @@ TEST_F(TestShard, TestStatisticPart) { | |||
| nlohmann::json statistic_json = json::parse(kStatistics[2]); | |||
| std::shared_ptr<Statistics> statistics = Statistics::Build(desc, statistic_json); | |||
| ASSERT_TRUE(statistics != nullptr); | |||
| MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc(); | |||
| MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump(); | |||
| MS_LOG(INFO) << "test get_desc(), result: " << statistics->GetDesc(); | |||
| MS_LOG(INFO) << "test get_statistics, result: " << statistics->GetStatistics().dump(); | |||
| std::string desc2 = "axis"; | |||
| nlohmann::json statistic_json2 = R"({})"; | |||
| @@ -111,13 +111,13 @@ TEST_F(TestShard, TestShardHeaderPart) { | |||
| ASSERT_EQ(res, 0); | |||
| header_data.AddStatistic(statistics1); | |||
| std::vector<Schema> re_schemas; | |||
| for (auto &schema_ptr : header_data.get_schemas()) { | |||
| for (auto &schema_ptr : header_data.GetSchemas()) { | |||
| re_schemas.push_back(*schema_ptr); | |||
| } | |||
| ASSERT_EQ(re_schemas, validate_schema); | |||
| std::vector<Statistics> re_statistics; | |||
| for (auto &statistic : header_data.get_statistics()) { | |||
| for (auto &statistic : header_data.GetStatistics()) { | |||
| re_statistics.push_back(*statistic); | |||
| } | |||
| ASSERT_EQ(re_statistics, validate_statistics); | |||
| @@ -129,7 +129,7 @@ TEST_F(TestShard, TestShardHeaderPart) { | |||
| std::pair<uint64_t, std::string> pair1(0, "name"); | |||
| fields.push_back(pair1); | |||
| ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS); | |||
| std::vector<std::pair<uint64_t, std::string>> resFields = header_data.get_fields(); | |||
| std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields(); | |||
| ASSERT_EQ(resFields, fields); | |||
| } | |||
| @@ -70,7 +70,7 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||
| int schema_id1 = header_data.AddSchema(schema1); | |||
| int schema_id2 = header_data.AddSchema(schema2); | |||
| ASSERT_EQ(schema_id2, -1); | |||
| ASSERT_EQ(header_data.get_schemas().size(), 1); | |||
| ASSERT_EQ(header_data.GetSchemas().size(), 1); | |||
| // check out fields | |||
| std::vector<std::pair<uint64_t, std::string>> fields; | |||
| @@ -81,35 +81,35 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||
| fields.push_back(index_field2); | |||
| MSRStatus res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, SUCCESS); | |||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field3(schema_id1, "name"); | |||
| fields.push_back(index_field3); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field4(schema_id1, "names"); | |||
| fields.push_back(index_field4); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name"); | |||
| fields.push_back(index_field5); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| fields.clear(); | |||
| std::pair<uint64_t, std::string> index_field6(schema_id1, "label"); | |||
| fields.push_back(index_field6); | |||
| res = header_data.AddIndexFields(fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||
| std::string desc_new = "this is a test1"; | |||
| json schemaContent_new = R"({"name": {"type": "string"}, | |||
| @@ -121,7 +121,7 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||
| mindrecord::ShardHeader header_data_new; | |||
| header_data_new.AddSchema(schema_new); | |||
| ASSERT_EQ(header_data_new.get_schemas().size(), 1); | |||
| ASSERT_EQ(header_data_new.GetSchemas().size(), 1); | |||
| // test add fields | |||
| std::vector<std::string> single_fields; | |||
| @@ -131,25 +131,25 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||
| single_fields.push_back("box"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| ASSERT_EQ(header_data_new.get_fields().size(), 1); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||
| single_fields.push_back("name"); | |||
| single_fields.push_back("box"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| ASSERT_EQ(header_data_new.get_fields().size(), 1); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||
| single_fields.clear(); | |||
| single_fields.push_back("names"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, FAILED); | |||
| ASSERT_EQ(header_data_new.get_fields().size(), 1); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||
| single_fields.clear(); | |||
| single_fields.push_back("box"); | |||
| res = header_data_new.AddIndexFields(single_fields); | |||
| ASSERT_EQ(res, SUCCESS); | |||
| ASSERT_EQ(header_data_new.get_fields().size(), 2); | |||
| ASSERT_EQ(header_data_new.GetFields().size(), 2); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -139,7 +139,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { | |||
| const int kPar = 2; | |||
| std::vector<std::shared_ptr<ShardOperator>> ops; | |||
| ops.push_back(std::make_shared<ShardSample>(kNum, kDen, kPar)); | |||
| auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->get_partitions(); | |||
| auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->GetPartitions(); | |||
| ASSERT_TRUE(partitions.first == 4); | |||
| ASSERT_TRUE(partitions.second == 2); | |||
| @@ -57,15 +57,15 @@ TEST_F(TestShardPage, TestBasic) { | |||
| Page page = | |||
| Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); | |||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||
| EXPECT_EQ(kGoldenSize, page.get_page_size()); | |||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||
| EXPECT_EQ(kGoldenEnd, page.get_end_row_id()); | |||
| ASSERT_TRUE(std::make_pair(4, kOffset) == page.get_last_row_group_id()); | |||
| ASSERT_TRUE(golden_row_group == page.get_row_group_ids()); | |||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||
| EXPECT_EQ(kGoldenSize, page.GetPageSize()); | |||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||
| EXPECT_EQ(kGoldenEnd, page.GetEndRowID()); | |||
| ASSERT_TRUE(std::make_pair(4, kOffset) == page.GetLastRowGroupID()); | |||
| ASSERT_TRUE(golden_row_group == page.GetRowGroupIds()); | |||
| } | |||
| TEST_F(TestShardPage, TestSetter) { | |||
| @@ -86,43 +86,43 @@ TEST_F(TestShardPage, TestSetter) { | |||
| Page page = | |||
| Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); | |||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||
| EXPECT_EQ(kGoldenSize, page.get_page_size()); | |||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||
| EXPECT_EQ(kGoldenEnd, page.get_end_row_id()); | |||
| ASSERT_TRUE(std::make_pair(4, kOffset1) == page.get_last_row_group_id()); | |||
| ASSERT_TRUE(golden_row_group == page.get_row_group_ids()); | |||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||
| EXPECT_EQ(kGoldenSize, page.GetPageSize()); | |||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||
| EXPECT_EQ(kGoldenEnd, page.GetEndRowID()); | |||
| ASSERT_TRUE(std::make_pair(4, kOffset1) == page.GetLastRowGroupID()); | |||
| ASSERT_TRUE(golden_row_group == page.GetRowGroupIds()); | |||
| const int kNewEnd = 33; | |||
| const int kNewSize = 300; | |||
| std::vector<std::pair<int, uint64_t>> new_row_group = {{0, 100}, {100, 200}, {200, 3000}}; | |||
| page.set_end_row_id(kNewEnd); | |||
| page.set_page_size(kNewSize); | |||
| page.set_row_group_ids(new_row_group); | |||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||
| EXPECT_EQ(kNewSize, page.get_page_size()); | |||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||
| EXPECT_EQ(kNewEnd, page.get_end_row_id()); | |||
| ASSERT_TRUE(std::make_pair(200, kOffset2) == page.get_last_row_group_id()); | |||
| ASSERT_TRUE(new_row_group == page.get_row_group_ids()); | |||
| page.SetEndRowID(kNewEnd); | |||
| page.SetPageSize(kNewSize); | |||
| page.SetRowGroupIds(new_row_group); | |||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||
| EXPECT_EQ(kNewSize, page.GetPageSize()); | |||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||
| EXPECT_EQ(kNewEnd, page.GetEndRowID()); | |||
| ASSERT_TRUE(std::make_pair(200, kOffset2) == page.GetLastRowGroupID()); | |||
| ASSERT_TRUE(new_row_group == page.GetRowGroupIds()); | |||
| page.DeleteLastGroupId(); | |||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||
| EXPECT_EQ(3000, page.get_page_size()); | |||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||
| EXPECT_EQ(kNewEnd, page.get_end_row_id()); | |||
| ASSERT_TRUE(std::make_pair(100, kOffset3) == page.get_last_row_group_id()); | |||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||
| EXPECT_EQ(3000, page.GetPageSize()); | |||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||
| EXPECT_EQ(kNewEnd, page.GetEndRowID()); | |||
| ASSERT_TRUE(std::make_pair(100, kOffset3) == page.GetLastRowGroupID()); | |||
| new_row_group.pop_back(); | |||
| ASSERT_TRUE(new_row_group == page.get_row_group_ids()); | |||
| ASSERT_TRUE(new_row_group == page.GetRowGroupIds()); | |||
| } | |||
| TEST_F(TestShardPage, TestJson) { | |||
| @@ -107,15 +107,15 @@ TEST_F(TestShardSchema, TestFunction) { | |||
| std::shared_ptr<Schema> schema = Schema::Build(desc, schema_content); | |||
| ASSERT_NE(schema, nullptr); | |||
| ASSERT_EQ(schema->get_desc(), desc); | |||
| ASSERT_EQ(schema->GetDesc(), desc); | |||
| json schema_json = schema->GetSchema(); | |||
| ASSERT_EQ(schema_json["desc"], desc); | |||
| ASSERT_EQ(schema_json["schema"], schema_content); | |||
| ASSERT_EQ(schema->get_schema_id(), -1); | |||
| schema->set_schema_id(2); | |||
| ASSERT_EQ(schema->get_schema_id(), 2); | |||
| ASSERT_EQ(schema->GetSchemaID(), -1); | |||
| schema->SetSchemaID(2); | |||
| ASSERT_EQ(schema->GetSchemaID(), 2); | |||
| } | |||
| TEST_F(TestStatistics, StatisticPart) { | |||
| @@ -137,8 +137,8 @@ TEST_F(TestStatistics, StatisticPart) { | |||
| ASSERT_NE(statistics, nullptr); | |||
| MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc(); | |||
| MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump(); | |||
| MS_LOG(INFO) << "test GetDesc(), result: " << statistics->GetDesc(); | |||
| MS_LOG(INFO) << "test GetStatistics, result: " << statistics->GetStatistics().dump(); | |||
| statistic_json["test"] = "test"; | |||
| statistics = Statistics::Build(desc, statistic_json); | |||
| @@ -194,8 +194,8 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) { | |||
| fw.Open(file_names); | |||
| uint64_t header_size = 1 << 14; | |||
| uint64_t page_size = 1 << 15; | |||
| fw.set_header_size(header_size); | |||
| fw.set_page_size(page_size); | |||
| fw.SetHeaderSize(header_size); | |||
| fw.SetPageSize(page_size); | |||
| // set shardHeader | |||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| @@ -331,8 +331,8 @@ TEST_F(TestShardWriter, TestShardWriterTrial) { | |||
| fw.Open(file_names); | |||
| uint64_t header_size = 1 << 14; | |||
| uint64_t page_size = 1 << 17; | |||
| fw.set_header_size(header_size); | |||
| fw.set_page_size(page_size); | |||
| fw.SetHeaderSize(header_size); | |||
| fw.SetPageSize(page_size); | |||
| // set shardHeader | |||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| @@ -466,8 +466,8 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) { | |||
| fw.Open(file_names); | |||
| uint64_t header_size = 1 << 14; | |||
| uint64_t page_size = 1 << 17; | |||
| fw.set_header_size(header_size); | |||
| fw.set_page_size(page_size); | |||
| fw.SetHeaderSize(header_size); | |||
| fw.SetPageSize(page_size); | |||
| // set shardHeader | |||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| @@ -567,8 +567,8 @@ TEST_F(TestShardWriter, DataCheck) { | |||
| fw.Open(file_names); | |||
| uint64_t header_size = 1 << 14; | |||
| uint64_t page_size = 1 << 17; | |||
| fw.set_header_size(header_size); | |||
| fw.set_page_size(page_size); | |||
| fw.SetHeaderSize(header_size); | |||
| fw.SetPageSize(page_size); | |||
| // set shardHeader | |||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||
| @@ -668,8 +668,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) { | |||
| fw.Open(file_names); | |||
| uint64_t header_size = 1 << 14; | |||
| uint64_t page_size = 1 << 17; | |||
| fw.set_header_size(header_size); | |||
| fw.set_page_size(page_size); | |||
| fw.SetHeaderSize(header_size); | |||
| fw.SetPageSize(page_size); | |||
| // set shardHeader | |||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | |||