Merge pull request !998 from guozhijian/enhance_format_func_name tags/v0.3.0-alpha
| @@ -108,7 +108,7 @@ Status MindRecordOp::Init() { | |||||
| data_schema_ = std::make_unique<DataSchema>(); | data_schema_ = std::make_unique<DataSchema>(); | ||||
| std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->get_shard_header()->get_schemas(); | |||||
| std::vector<std::shared_ptr<Schema>> schema_vec = shard_reader_->GetShardHeader()->GetSchemas(); | |||||
| // check whether schema exists, if so use the first one | // check whether schema exists, if so use the first one | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found"); | CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found"); | ||||
| mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"]; | mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"]; | ||||
| @@ -155,7 +155,7 @@ Status MindRecordOp::Init() { | |||||
| column_name_mapping_[columns_to_load_[i]] = i; | column_name_mapping_[columns_to_load_[i]] = i; | ||||
| } | } | ||||
| num_rows_ = shard_reader_->get_num_rows(); | |||||
| num_rows_ = shard_reader_->GetNumRows(); | |||||
| // Compute how many buffers we would need to accomplish rowsPerBuffer | // Compute how many buffers we would need to accomplish rowsPerBuffer | ||||
| buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; | buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; | ||||
| RETURN_IF_NOT_OK(SetColumnsBlob()); | RETURN_IF_NOT_OK(SetColumnsBlob()); | ||||
| @@ -164,7 +164,7 @@ Status MindRecordOp::Init() { | |||||
| } | } | ||||
| Status MindRecordOp::SetColumnsBlob() { | Status MindRecordOp::SetColumnsBlob() { | ||||
| columns_blob_ = shard_reader_->get_blob_fields().second; | |||||
| columns_blob_ = shard_reader_->GetBlobFields().second; | |||||
| // get the exact blob fields by columns_to_load_ | // get the exact blob fields by columns_to_load_ | ||||
| std::vector<std::string> columns_blob_exact; | std::vector<std::string> columns_blob_exact; | ||||
| @@ -600,7 +600,7 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { | |||||
| // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work | // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work | ||||
| Status MindRecordOp::operator()() { | Status MindRecordOp::operator()() { | ||||
| RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); | RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); | ||||
| num_rows_ = shard_reader_->get_num_rows(); | |||||
| num_rows_ = shard_reader_->GetNumRows(); | |||||
| buffers_needed_ = num_rows_ / rows_per_buffer_; | buffers_needed_ = num_rows_ / rows_per_buffer_; | ||||
| if (num_rows_ % rows_per_buffer_ != 0) { | if (num_rows_ % rows_per_buffer_ != 0) { | ||||
| @@ -39,18 +39,18 @@ namespace mindrecord { | |||||
| void BindSchema(py::module *m) { | void BindSchema(py::module *m) { | ||||
| (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local()) | (void)py::class_<Schema, std::shared_ptr<Schema>>(*m, "Schema", py::module_local()) | ||||
| .def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build) | .def_static("build", (std::shared_ptr<Schema>(*)(std::string, py::handle)) & Schema::Build) | ||||
| .def("get_desc", &Schema::get_desc) | |||||
| .def("get_desc", &Schema::GetDesc) | |||||
| .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) | .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) | ||||
| .def("get_blob_fields", &Schema::get_blob_fields) | |||||
| .def("get_schema_id", &Schema::get_schema_id); | |||||
| .def("get_blob_fields", &Schema::GetBlobFields) | |||||
| .def("get_schema_id", &Schema::GetSchemaID); | |||||
| } | } | ||||
| void BindStatistics(const py::module *m) { | void BindStatistics(const py::module *m) { | ||||
| (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local()) | (void)py::class_<Statistics, std::shared_ptr<Statistics>>(*m, "Statistics", py::module_local()) | ||||
| .def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build) | .def_static("build", (std::shared_ptr<Statistics>(*)(std::string, py::handle)) & Statistics::Build) | ||||
| .def("get_desc", &Statistics::get_desc) | |||||
| .def("get_desc", &Statistics::GetDesc) | |||||
| .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) | .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) | ||||
| .def("get_statistics_id", &Statistics::get_statistics_id); | |||||
| .def("get_statistics_id", &Statistics::GetStatisticsID); | |||||
| } | } | ||||
| void BindShardHeader(const py::module *m) { | void BindShardHeader(const py::module *m) { | ||||
| @@ -60,9 +60,9 @@ void BindShardHeader(const py::module *m) { | |||||
| .def("add_statistics", &ShardHeader::AddStatistic) | .def("add_statistics", &ShardHeader::AddStatistic) | ||||
| .def("add_index_fields", | .def("add_index_fields", | ||||
| (MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields) | (MSRStatus(ShardHeader::*)(const std::vector<std::string> &)) & ShardHeader::AddIndexFields) | ||||
| .def("get_meta", &ShardHeader::get_schemas) | |||||
| .def("get_statistics", &ShardHeader::get_statistics) | |||||
| .def("get_fields", &ShardHeader::get_fields) | |||||
| .def("get_meta", &ShardHeader::GetSchemas) | |||||
| .def("get_statistics", &ShardHeader::GetStatistics) | |||||
| .def("get_fields", &ShardHeader::GetFields) | |||||
| .def("get_schema_by_id", &ShardHeader::GetSchemaByID) | .def("get_schema_by_id", &ShardHeader::GetSchemaByID) | ||||
| .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); | .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); | ||||
| } | } | ||||
| @@ -72,8 +72,8 @@ void BindShardWriter(py::module *m) { | |||||
| .def(py::init<>()) | .def(py::init<>()) | ||||
| .def("open", &ShardWriter::Open) | .def("open", &ShardWriter::Open) | ||||
| .def("open_for_append", &ShardWriter::OpenForAppend) | .def("open_for_append", &ShardWriter::OpenForAppend) | ||||
| .def("set_header_size", &ShardWriter::set_header_size) | |||||
| .def("set_page_size", &ShardWriter::set_page_size) | |||||
| .def("set_header_size", &ShardWriter::SetHeaderSize) | |||||
| .def("set_page_size", &ShardWriter::SetPageSize) | |||||
| .def("set_shard_header", &ShardWriter::SetShardHeader) | .def("set_shard_header", &ShardWriter::SetShardHeader) | ||||
| .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, | .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map<uint64_t, std::vector<py::handle>> &, | ||||
| vector<vector<uint8_t>> &, bool, bool)) & | vector<vector<uint8_t>> &, bool, bool)) & | ||||
| @@ -88,8 +88,8 @@ void BindShardReader(const py::module *m) { | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &)) & | const std::vector<std::shared_ptr<ShardOperator>> &)) & | ||||
| ShardReader::OpenPy) | ShardReader::OpenPy) | ||||
| .def("launch", &ShardReader::Launch) | .def("launch", &ShardReader::Launch) | ||||
| .def("get_header", &ShardReader::get_shard_header) | |||||
| .def("get_blob_fields", &ShardReader::get_blob_fields) | |||||
| .def("get_header", &ShardReader::GetShardHeader) | |||||
| .def("get_blob_fields", &ShardReader::GetBlobFields) | |||||
| .def("get_next", | .def("get_next", | ||||
| (std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) | (std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) | ||||
| .def("finish", &ShardReader::Finish) | .def("finish", &ShardReader::Finish) | ||||
| @@ -119,9 +119,9 @@ void BindShardSegment(py::module *m) { | |||||
| .def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>( | .def("read_at_page_by_name", (std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>>( | ||||
| ShardSegment::*)(std::string, int64_t, int64_t)) & | ShardSegment::*)(std::string, int64_t, int64_t)) & | ||||
| ShardSegment::ReadAtPageByNamePy) | ShardSegment::ReadAtPageByNamePy) | ||||
| .def("get_header", &ShardSegment::get_shard_header) | |||||
| .def("get_header", &ShardSegment::GetShardHeader) | |||||
| .def("get_blob_fields", | .def("get_blob_fields", | ||||
| (std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::get_blob_fields); | |||||
| (std::pair<ShardType, std::vector<std::string>>(ShardSegment::*)()) & ShardSegment::GetBlobFields); | |||||
| } | } | ||||
| void BindGlobalParams(py::module *m) { | void BindGlobalParams(py::module *m) { | ||||
| @@ -36,7 +36,7 @@ class ShardCategory : public ShardOperator { | |||||
| ~ShardCategory() override{}; | ~ShardCategory() override{}; | ||||
| const std::vector<std::pair<std::string, std::string>> &get_categories() const { return categories_; } | |||||
| const std::vector<std::pair<std::string, std::string>> &GetCategories() const { return categories_; } | |||||
| const std::string GetCategoryField() const { return category_field_; } | const std::string GetCategoryField() const { return category_field_; } | ||||
| @@ -46,7 +46,7 @@ class ShardCategory : public ShardOperator { | |||||
| bool GetReplacement() const { return replacement_; } | bool GetReplacement() const { return replacement_; } | ||||
| MSRStatus execute(ShardTask &tasks) override; | |||||
| MSRStatus Execute(ShardTask &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | ||||
| @@ -58,19 +58,19 @@ class ShardHeader { | |||||
| /// \brief get the schema | /// \brief get the schema | ||||
| /// \return the schema | /// \return the schema | ||||
| std::vector<std::shared_ptr<Schema>> get_schemas(); | |||||
| std::vector<std::shared_ptr<Schema>> GetSchemas(); | |||||
| /// \brief get Statistics | /// \brief get Statistics | ||||
| /// \return the Statistic | /// \return the Statistic | ||||
| std::vector<std::shared_ptr<Statistics>> get_statistics(); | |||||
| std::vector<std::shared_ptr<Statistics>> GetStatistics(); | |||||
| /// \brief get the fields of the index | /// \brief get the fields of the index | ||||
| /// \return the fields of the index | /// \return the fields of the index | ||||
| std::vector<std::pair<uint64_t, std::string>> get_fields(); | |||||
| std::vector<std::pair<uint64_t, std::string>> GetFields(); | |||||
| /// \brief get the index | /// \brief get the index | ||||
| /// \return the index | /// \return the index | ||||
| std::shared_ptr<Index> get_index(); | |||||
| std::shared_ptr<Index> GetIndex(); | |||||
| /// \brief get the schema by schemaid | /// \brief get the schema by schemaid | ||||
| /// \param[in] schemaId the id of the schema to get | /// \param[in] schemaId the id of the schema to get | ||||
| @@ -80,7 +80,7 @@ class ShardHeader { | |||||
| /// \brief get the filepath to shard by shardID | /// \brief get the filepath to shard by shardID | ||||
| /// \param[in] shardID the id of shard which filepath needs to be obtained | /// \param[in] shardID the id of shard which filepath needs to be obtained | ||||
| /// \return the filepath obtained by shardID | /// \return the filepath obtained by shardID | ||||
| std::string get_shard_address_by_id(int64_t shard_id); | |||||
| std::string GetShardAddressByID(int64_t shard_id); | |||||
| /// \brief get the statistic by statistic id | /// \brief get the statistic by statistic id | ||||
| /// \param[in] statisticId the id of the statistic to get | /// \param[in] statisticId the id of the statistic to get | ||||
| @@ -89,7 +89,7 @@ class ShardHeader { | |||||
| MSRStatus InitByFiles(const std::vector<std::string> &file_paths); | MSRStatus InitByFiles(const std::vector<std::string> &file_paths); | ||||
| void set_index(Index index) { index_ = std::make_shared<Index>(index); } | |||||
| void SetIndex(Index index) { index_ = std::make_shared<Index>(index); } | |||||
| std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id); | std::pair<std::shared_ptr<Page>, MSRStatus> GetPage(const int &shard_id, const int &page_id); | ||||
| @@ -103,21 +103,21 @@ class ShardHeader { | |||||
| const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id); | const std::pair<MSRStatus, std::shared_ptr<Page>> GetPageByGroupId(const int &group_id, const int &shard_id); | ||||
| std::vector<std::string> get_shard_addresses() const { return shard_addresses_; } | |||||
| std::vector<std::string> GetShardAddresses() const { return shard_addresses_; } | |||||
| int get_shard_count() const { return shard_count_; } | |||||
| int GetShardCount() const { return shard_count_; } | |||||
| int get_schema_count() const { return schema_.size(); } | |||||
| int GetSchemaCount() const { return schema_.size(); } | |||||
| uint64_t get_header_size() const { return header_size_; } | |||||
| uint64_t GetHeaderSize() const { return header_size_; } | |||||
| uint64_t get_page_size() const { return page_size_; } | |||||
| uint64_t GetPageSize() const { return page_size_; } | |||||
| void set_header_size(const uint64_t &header_size) { header_size_ = header_size; } | |||||
| void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } | |||||
| void set_page_size(const uint64_t &page_size) { page_size_ = page_size; } | |||||
| void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | |||||
| const string get_version() { return version_; } | |||||
| const string GetVersion() { return version_; } | |||||
| std::vector<std::string> SerializeHeader(); | std::vector<std::string> SerializeHeader(); | ||||
| @@ -132,7 +132,7 @@ class ShardHeader { | |||||
| /// \param[in] the shard data real path | /// \param[in] the shard data real path | ||||
| /// \param[in] the headers read from the shard data | /// \param[in] the headers read from the shard data | ||||
| /// \return SUCCESS/FAILED | /// \return SUCCESS/FAILED | ||||
| MSRStatus get_headers(const vector<string> &real_addresses, std::vector<json> &headers); | |||||
| MSRStatus GetHeaders(const vector<string> &real_addresses, std::vector<json> &headers); | |||||
| MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); | MSRStatus ValidateField(const std::vector<std::string> &field_name, json schema, const uint64_t &schema_id); | ||||
| @@ -52,7 +52,7 @@ class Index { | |||||
| /// \brief get stored fields | /// \brief get stored fields | ||||
| /// \return fields stored | /// \return fields stored | ||||
| std::vector<std::pair<uint64_t, std::string> > get_fields(); | |||||
| std::vector<std::pair<uint64_t, std::string> > GetFields(); | |||||
| private: | private: | ||||
| std::vector<std::pair<uint64_t, std::string> > fields_; | std::vector<std::pair<uint64_t, std::string> > fields_; | ||||
| @@ -26,23 +26,23 @@ class ShardOperator { | |||||
| virtual ~ShardOperator() = default; | virtual ~ShardOperator() = default; | ||||
| MSRStatus operator()(ShardTask &tasks) { | MSRStatus operator()(ShardTask &tasks) { | ||||
| if (SUCCESS != this->pre_execute(tasks)) { | |||||
| if (SUCCESS != this->PreExecute(tasks)) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (SUCCESS != this->execute(tasks)) { | |||||
| if (SUCCESS != this->Execute(tasks)) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (SUCCESS != this->suf_execute(tasks)) { | |||||
| if (SUCCESS != this->SufExecute(tasks)) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; } | |||||
| virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } | |||||
| virtual MSRStatus execute(ShardTask &tasks) = 0; | |||||
| virtual MSRStatus Execute(ShardTask &tasks) = 0; | |||||
| virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } | |||||
| virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } | |||||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } | virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } | ||||
| }; | }; | ||||
| @@ -53,29 +53,29 @@ class Page { | |||||
| /// \return the json format of the page and its description | /// \return the json format of the page and its description | ||||
| json GetPage() const; | json GetPage() const; | ||||
| int get_page_id() const { return page_id_; } | |||||
| int GetPageID() const { return page_id_; } | |||||
| int get_shard_id() const { return shard_id_; } | |||||
| int GetShardID() const { return shard_id_; } | |||||
| int get_page_type_id() const { return page_type_id_; } | |||||
| int GetPageTypeID() const { return page_type_id_; } | |||||
| std::string get_page_type() const { return page_type_; } | |||||
| std::string GetPageType() const { return page_type_; } | |||||
| uint64_t get_page_size() const { return page_size_; } | |||||
| uint64_t GetPageSize() const { return page_size_; } | |||||
| uint64_t get_start_row_id() const { return start_row_id_; } | |||||
| uint64_t GetStartRowID() const { return start_row_id_; } | |||||
| uint64_t get_end_row_id() const { return end_row_id_; } | |||||
| uint64_t GetEndRowID() const { return end_row_id_; } | |||||
| void set_end_row_id(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } | |||||
| void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } | |||||
| void set_page_size(const uint64_t &page_size) { page_size_ = page_size; } | |||||
| void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } | |||||
| std::pair<int, uint64_t> get_last_row_group_id() const { return row_group_ids_.back(); } | |||||
| std::pair<int, uint64_t> GetLastRowGroupID() const { return row_group_ids_.back(); } | |||||
| std::vector<std::pair<int, uint64_t>> get_row_group_ids() const { return row_group_ids_; } | |||||
| std::vector<std::pair<int, uint64_t>> GetRowGroupIds() const { return row_group_ids_; } | |||||
| void set_row_group_ids(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) { | |||||
| void SetRowGroupIds(const std::vector<std::pair<int, uint64_t>> &last_row_group_ids) { | |||||
| row_group_ids_ = last_row_group_ids; | row_group_ids_ = last_row_group_ids; | ||||
| } | } | ||||
| @@ -37,7 +37,7 @@ class ShardPkSample : public ShardCategory { | |||||
| ~ShardPkSample() override{}; | ~ShardPkSample() override{}; | ||||
| MSRStatus suf_execute(ShardTask &tasks) override; | |||||
| MSRStatus SufExecute(ShardTask &tasks) override; | |||||
| private: | private: | ||||
| bool shuffle_; | bool shuffle_; | ||||
| @@ -107,11 +107,11 @@ class ShardReader { | |||||
| /// \brief aim to get the meta data | /// \brief aim to get the meta data | ||||
| /// \return the metadata | /// \return the metadata | ||||
| std::shared_ptr<ShardHeader> get_shard_header() const; | |||||
| std::shared_ptr<ShardHeader> GetShardHeader() const; | |||||
| /// \brief get the number of shards | /// \brief get the number of shards | ||||
| /// \return # of shards | /// \return # of shards | ||||
| int get_shard_count() const; | |||||
| int GetShardCount() const; | |||||
| /// \brief get the number of rows in database | /// \brief get the number of rows in database | ||||
| /// \param[in] file_path the path of ONE file, any file in dataset is fine | /// \param[in] file_path the path of ONE file, any file in dataset is fine | ||||
| @@ -126,7 +126,7 @@ class ShardReader { | |||||
| /// \brief get the number of rows in database | /// \brief get the number of rows in database | ||||
| /// \return # of rows | /// \return # of rows | ||||
| int get_num_rows() const; | |||||
| int GetNumRows() const; | |||||
| /// \brief Read the summary of row groups | /// \brief Read the summary of row groups | ||||
| /// \return the tuple of 4 elements | /// \return the tuple of 4 elements | ||||
| @@ -185,7 +185,7 @@ class ShardReader { | |||||
| /// \brief get blob field list | /// \brief get blob field list | ||||
| /// \return blob field list | /// \return blob field list | ||||
| std::pair<ShardType, std::vector<std::string>> get_blob_fields(); | |||||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | |||||
| /// \brief reset reader | /// \brief reset reader | ||||
| /// \return null | /// \return null | ||||
| @@ -193,10 +193,10 @@ class ShardReader { | |||||
| /// \brief set flag of all-in-index | /// \brief set flag of all-in-index | ||||
| /// \return null | /// \return null | ||||
| void set_all_in_index(bool all_in_index) { all_in_index_ = all_in_index; } | |||||
| void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } | |||||
| /// \brief get NLP flag | /// \brief get NLP flag | ||||
| bool get_nlp_flag(); | |||||
| bool GetNlpFlag(); | |||||
| /// \brief get all classes | /// \brief get all classes | ||||
| MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); | MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); | ||||
| @@ -38,11 +38,11 @@ class ShardSample : public ShardOperator { | |||||
| ~ShardSample() override{}; | ~ShardSample() override{}; | ||||
| const std::pair<int, int> get_partitions() const; | |||||
| const std::pair<int, int> GetPartitions() const; | |||||
| MSRStatus execute(ShardTask &tasks) override; | |||||
| MSRStatus Execute(ShardTask &tasks) override; | |||||
| MSRStatus suf_execute(ShardTask &tasks) override; | |||||
| MSRStatus SufExecute(ShardTask &tasks) override; | |||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | ||||
| @@ -51,7 +51,7 @@ class Schema { | |||||
| /// \brief get the schema and its description | /// \brief get the schema and its description | ||||
| /// \return the json format of the schema and its description | /// \return the json format of the schema and its description | ||||
| std::string get_desc() const; | |||||
| std::string GetDesc() const; | |||||
| /// \brief get the schema and its description | /// \brief get the schema and its description | ||||
| /// \return the json format of the schema and its description | /// \return the json format of the schema and its description | ||||
| @@ -63,15 +63,15 @@ class Schema { | |||||
| /// set the schema id | /// set the schema id | ||||
| /// \param[in] id the id need to be set | /// \param[in] id the id need to be set | ||||
| void set_schema_id(int64_t id); | |||||
| void SetSchemaID(int64_t id); | |||||
| /// get the schema id | /// get the schema id | ||||
| /// \return the int64 schema id | /// \return the int64 schema id | ||||
| int64_t get_schema_id() const; | |||||
| int64_t GetSchemaID() const; | |||||
| /// get the blob fields | /// get the blob fields | ||||
| /// \return the vector<string> blob fields | /// \return the vector<string> blob fields | ||||
| std::vector<std::string> get_blob_fields() const; | |||||
| std::vector<std::string> GetBlobFields() const; | |||||
| private: | private: | ||||
| Schema() = default; | Schema() = default; | ||||
| @@ -81,7 +81,7 @@ class ShardSegment : public ShardReader { | |||||
| std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy( | std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>> ReadAtPageByNamePy( | ||||
| std::string category_name, int64_t page_no, int64_t n_rows_of_page); | std::string category_name, int64_t page_no, int64_t n_rows_of_page); | ||||
| std::pair<ShardType, std::vector<std::string>> get_blob_fields(); | |||||
| std::pair<ShardType, std::vector<std::string>> GetBlobFields(); | |||||
| private: | private: | ||||
| std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo(); | std::pair<MSRStatus, std::vector<std::tuple<int, std::string, int>>> WrapCategoryInfo(); | ||||
| @@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator { | |||||
| ~ShardShuffle() override{}; | ~ShardShuffle() override{}; | ||||
| MSRStatus execute(ShardTask &tasks) override; | |||||
| MSRStatus Execute(ShardTask &tasks) override; | |||||
| private: | private: | ||||
| uint32_t shuffle_seed_; | uint32_t shuffle_seed_; | ||||
| @@ -53,11 +53,11 @@ class Statistics { | |||||
| /// \brief get the description | /// \brief get the description | ||||
| /// \return the description | /// \return the description | ||||
| std::string get_desc() const; | |||||
| std::string GetDesc() const; | |||||
| /// \brief get the statistic | /// \brief get the statistic | ||||
| /// \return json format of the statistic | /// \return json format of the statistic | ||||
| json get_statistics() const; | |||||
| json GetStatistics() const; | |||||
| /// \brief get the statistic for python | /// \brief get the statistic for python | ||||
| /// \return the python object of statistics | /// \return the python object of statistics | ||||
| @@ -66,11 +66,11 @@ class Statistics { | |||||
| /// \brief decode the bson statistics to json | /// \brief decode the bson statistics to json | ||||
| /// \param[in] encodedStatistics the bson type of statistics | /// \param[in] encodedStatistics the bson type of statistics | ||||
| /// \return json type of statistic | /// \return json type of statistic | ||||
| void set_statistics_id(int64_t id); | |||||
| void SetStatisticsID(int64_t id); | |||||
| /// \brief get the statistics id | /// \brief get the statistics id | ||||
| /// \return the int64 statistics id | /// \return the int64 statistics id | ||||
| int64_t get_statistics_id() const; | |||||
| int64_t GetStatisticsID() const; | |||||
| private: | private: | ||||
| /// \brief validate the statistic | /// \brief validate the statistic | ||||
| @@ -39,9 +39,9 @@ class ShardTask { | |||||
| uint32_t SizeOfRows() const; | uint32_t SizeOfRows() const; | ||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id); | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetTaskByID(size_t id); | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_random_task(); | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask(); | |||||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements); | static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements); | ||||
| @@ -69,12 +69,12 @@ class ShardWriter { | |||||
| /// \brief Set file size | /// \brief Set file size | ||||
| /// \param[in] header_size the size of header, only (1<<N) is accepted | /// \param[in] header_size the size of header, only (1<<N) is accepted | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus set_header_size(const uint64_t &header_size); | |||||
| MSRStatus SetHeaderSize(const uint64_t &header_size); | |||||
| /// \brief Set page size | /// \brief Set page size | ||||
| /// \param[in] page_size the size of page, only (1<<N) is accepted | /// \param[in] page_size the size of page, only (1<<N) is accepted | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus set_page_size(const uint64_t &page_size); | |||||
| MSRStatus SetPageSize(const uint64_t &page_size); | |||||
| /// \brief Set shard header | /// \brief Set shard header | ||||
| /// \param[in] header_data the info of header | /// \param[in] header_data the info of header | ||||
| @@ -64,7 +64,7 @@ std::pair<MSRStatus, std::string> ShardIndexGenerator::GetValueByField(const str | |||||
| } | } | ||||
| // schema does not contain the field | // schema does not contain the field | ||||
| auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"]; | |||||
| auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; | |||||
| if (schema.find(field) == schema.end()) { | if (schema.find(field) == schema.end()) { | ||||
| MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; | MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; | ||||
| return {FAILED, ""}; | return {FAILED, ""}; | ||||
| @@ -203,7 +203,7 @@ MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::stri | |||||
| } | } | ||||
| std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) { | std::pair<MSRStatus, sqlite3 *> ShardIndexGenerator::CreateDatabase(int shard_no) { | ||||
| std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); | |||||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | |||||
| if (shard_address.empty()) { | if (shard_address.empty()) { | ||||
| MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; | MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; | ||||
| return {FAILED, nullptr}; | return {FAILED, nullptr}; | ||||
| @@ -357,12 +357,12 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( | |||||
| MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector<std::tuple<std::string, std::string, std::string>> &row_data, | ||||
| const std::shared_ptr<Page> cur_blob_page, | const std::shared_ptr<Page> cur_blob_page, | ||||
| uint64_t &cur_blob_page_offset, std::fstream &in) { | uint64_t &cur_blob_page_offset, std::fstream &in) { | ||||
| row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->get_page_id())); | |||||
| row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); | |||||
| // blob data start | // blob data start | ||||
| row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); | row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); | ||||
| auto &io_seekg_blob = | auto &io_seekg_blob = | ||||
| in.seekg(page_size_ * cur_blob_page->get_page_id() + header_size_ + cur_blob_page_offset, std::ios::beg); | |||||
| in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); | |||||
| if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { | if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { | ||||
| MS_LOG(ERROR) << "File seekg failed"; | MS_LOG(ERROR) << "File seekg failed"; | ||||
| in.close(); | in.close(); | ||||
| @@ -405,7 +405,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||||
| std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; | std::shared_ptr<Page> cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; | ||||
| // related blob page | // related blob page | ||||
| vector<pair<int, uint64_t>> row_group_list = cur_raw_page->get_row_group_ids(); | |||||
| vector<pair<int, uint64_t>> row_group_list = cur_raw_page->GetRowGroupIds(); | |||||
| // pair: row_group id, offset in raw data page | // pair: row_group id, offset in raw data page | ||||
| for (pair<int, int> blob_ids : row_group_list) { | for (pair<int, int> blob_ids : row_group_list) { | ||||
| @@ -415,18 +415,18 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||||
| // offset in current raw data page | // offset in current raw data page | ||||
| auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second); | auto cur_raw_page_offset = static_cast<uint64_t>(blob_ids.second); | ||||
| uint64_t cur_blob_page_offset = 0; | uint64_t cur_blob_page_offset = 0; | ||||
| for (unsigned int i = cur_blob_page->get_start_row_id(); i < cur_blob_page->get_end_row_id(); ++i) { | |||||
| for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { | |||||
| std::vector<std::tuple<std::string, std::string, std::string>> row_data; | std::vector<std::tuple<std::string, std::string, std::string>> row_data; | ||||
| row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); | row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); | ||||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->get_page_type_id())); | |||||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->get_page_id())); | |||||
| row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); | |||||
| row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); | |||||
| // raw data start | // raw data start | ||||
| row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); | row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); | ||||
| // calculate raw data end | // calculate raw data end | ||||
| auto &io_seekg = | auto &io_seekg = | ||||
| in.seekg(page_size_ * (cur_raw_page->get_page_id()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||||
| in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); | |||||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | ||||
| MS_LOG(ERROR) << "File seekg failed"; | MS_LOG(ERROR) << "File seekg failed"; | ||||
| in.close(); | in.close(); | ||||
| @@ -473,7 +473,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map<int, | |||||
| INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) { | INDEX_FIELDS ShardIndexGenerator::GenerateIndexFields(const std::vector<json> &schema_detail) { | ||||
| std::vector<std::tuple<std::string, std::string, std::string>> fields; | std::vector<std::tuple<std::string, std::string, std::string>> fields; | ||||
| // index fields | // index fields | ||||
| std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.get_fields(); | |||||
| std::vector<std::pair<uint64_t, std::string>> index_fields = shard_header_.GetFields(); | |||||
| for (const auto &field : index_fields) { | for (const auto &field : index_fields) { | ||||
| if (field.first >= schema_detail.size()) { | if (field.first >= schema_detail.size()) { | ||||
| return {FAILED, {}}; | return {FAILED, {}}; | ||||
| @@ -504,7 +504,7 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std | |||||
| const std::vector<int> &raw_page_ids, | const std::vector<int> &raw_page_ids, | ||||
| const std::map<int, int> &blob_id_to_page_id) { | const std::map<int, int> &blob_id_to_page_id) { | ||||
| // Add index data to database | // Add index data to database | ||||
| std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); | |||||
| std::string shard_address = shard_header_.GetShardAddressByID(shard_no); | |||||
| if (shard_address.empty()) { | if (shard_address.empty()) { | ||||
| MS_LOG(ERROR) << "Shard address is null"; | MS_LOG(ERROR) << "Shard address is null"; | ||||
| return FAILED; | return FAILED; | ||||
| @@ -546,12 +546,12 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std | |||||
| } | } | ||||
| MSRStatus ShardIndexGenerator::WriteToDatabase() { | MSRStatus ShardIndexGenerator::WriteToDatabase() { | ||||
| fields_ = shard_header_.get_fields(); | |||||
| page_size_ = shard_header_.get_page_size(); | |||||
| header_size_ = shard_header_.get_header_size(); | |||||
| schema_count_ = shard_header_.get_schema_count(); | |||||
| if (shard_header_.get_shard_count() > kMaxShardCount) { | |||||
| MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount; | |||||
| fields_ = shard_header_.GetFields(); | |||||
| page_size_ = shard_header_.GetPageSize(); | |||||
| header_size_ = shard_header_.GetHeaderSize(); | |||||
| schema_count_ = shard_header_.GetSchemaCount(); | |||||
| if (shard_header_.GetShardCount() > kMaxShardCount) { | |||||
| MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| task_ = 0; // set two atomic vars to initial value | task_ = 0; // set two atomic vars to initial value | ||||
| @@ -559,7 +559,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||||
| // spawn half the physical threads or total number of shards whichever is smaller | // spawn half the physical threads or total number of shards whichever is smaller | ||||
| const unsigned int num_workers = | const unsigned int num_workers = | ||||
| std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.get_shard_count())); | |||||
| std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast<unsigned int>(shard_header_.GetShardCount())); | |||||
| std::vector<std::thread> threads; | std::vector<std::thread> threads; | ||||
| threads.reserve(num_workers); | threads.reserve(num_workers); | ||||
| @@ -576,7 +576,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { | |||||
| void ShardIndexGenerator::DatabaseWriter() { | void ShardIndexGenerator::DatabaseWriter() { | ||||
| int shard_no = task_++; | int shard_no = task_++; | ||||
| while (shard_no < shard_header_.get_shard_count()) { | |||||
| while (shard_no < shard_header_.GetShardCount()) { | |||||
| auto db = CreateDatabase(shard_no); | auto db = CreateDatabase(shard_no); | ||||
| if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { | if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { | ||||
| write_success_ = false; | write_success_ = false; | ||||
| @@ -592,10 +592,10 @@ void ShardIndexGenerator::DatabaseWriter() { | |||||
| std::vector<int> raw_page_ids; | std::vector<int> raw_page_ids; | ||||
| for (uint64_t i = 0; i < total_pages; ++i) { | for (uint64_t i = 0; i < total_pages; ++i) { | ||||
| std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first; | std::shared_ptr<Page> cur_page = shard_header_.GetPage(shard_no, i).first; | ||||
| if (cur_page->get_page_type() == "RAW_DATA") { | |||||
| if (cur_page->GetPageType() == "RAW_DATA") { | |||||
| raw_page_ids.push_back(i); | raw_page_ids.push_back(i); | ||||
| } else if (cur_page->get_page_type() == "BLOB_DATA") { | |||||
| blob_id_to_page_id[cur_page->get_page_type_id()] = i; | |||||
| } else if (cur_page->GetPageType() == "BLOB_DATA") { | |||||
| blob_id_to_page_id[cur_page->GetPageTypeID()] = i; | |||||
| } | } | ||||
| } | } | ||||
| @@ -56,9 +56,9 @@ MSRStatus ShardReader::Init(const std::string &file_path) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| shard_header_ = std::make_shared<ShardHeader>(sh); | shard_header_ = std::make_shared<ShardHeader>(sh); | ||||
| header_size_ = shard_header_->get_header_size(); | |||||
| page_size_ = shard_header_->get_page_size(); | |||||
| file_paths_ = shard_header_->get_shard_addresses(); | |||||
| header_size_ = shard_header_->GetHeaderSize(); | |||||
| page_size_ = shard_header_->GetPageSize(); | |||||
| file_paths_ = shard_header_->GetShardAddresses(); | |||||
| for (const auto &file : file_paths_) { | for (const auto &file : file_paths_) { | ||||
| sqlite3 *db = nullptr; | sqlite3 *db = nullptr; | ||||
| @@ -105,7 +105,7 @@ MSRStatus ShardReader::Init(const std::string &file_path) { | |||||
| MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) { | MSRStatus ShardReader::CheckColumnList(const std::vector<std::string> &selected_columns) { | ||||
| vector<int> inSchema(selected_columns.size(), 0); | vector<int> inSchema(selected_columns.size(), 0); | ||||
| for (auto &p : get_shard_header()->get_schemas()) { | |||||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||||
| auto schema = p->GetSchema()["schema"]; | auto schema = p->GetSchema()["schema"]; | ||||
| for (unsigned int i = 0; i < selected_columns.size(); ++i) { | for (unsigned int i = 0; i < selected_columns.size(); ++i) { | ||||
| if (schema.find(selected_columns[i]) != schema.end()) { | if (schema.find(selected_columns[i]) != schema.end()) { | ||||
| @@ -183,15 +183,15 @@ void ShardReader::Close() { | |||||
| FileStreamsOperator(); | FileStreamsOperator(); | ||||
| } | } | ||||
| std::shared_ptr<ShardHeader> ShardReader::get_shard_header() const { return shard_header_; } | |||||
| std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; } | |||||
| int ShardReader::get_shard_count() const { return shard_header_->get_shard_count(); } | |||||
| int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } | |||||
| int ShardReader::get_num_rows() const { return num_rows_; } | |||||
| int ShardReader::GetNumRows() const { return num_rows_; } | |||||
| std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() { | std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummary() { | ||||
| std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary; | std::vector<std::tuple<int, int, int, uint64_t>> row_group_summary; | ||||
| int shard_count = shard_header_->get_shard_count(); | |||||
| int shard_count = shard_header_->GetShardCount(); | |||||
| if (shard_count <= 0) { | if (shard_count <= 0) { | ||||
| return row_group_summary; | return row_group_summary; | ||||
| } | } | ||||
| @@ -205,13 +205,13 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar | |||||
| for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { | for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { | ||||
| const auto &page_t = shard_header_->GetPage(shard_id, page_id); | const auto &page_t = shard_header_->GetPage(shard_id, page_id); | ||||
| const auto &page = page_t.first; | const auto &page = page_t.first; | ||||
| if (page->get_page_type() != kPageTypeBlob) continue; | |||||
| uint64_t start_row_id = page->get_start_row_id(); | |||||
| if (start_row_id > page->get_end_row_id()) { | |||||
| if (page->GetPageType() != kPageTypeBlob) continue; | |||||
| uint64_t start_row_id = page->GetStartRowID(); | |||||
| if (start_row_id > page->GetEndRowID()) { | |||||
| return std::vector<std::tuple<int, int, int, uint64_t>>(); | return std::vector<std::tuple<int, int, int, uint64_t>>(); | ||||
| } | } | ||||
| uint64_t number_of_rows = page->get_end_row_id() - start_row_id; | |||||
| row_group_summary.emplace_back(shard_id, page->get_page_type_id(), start_row_id, number_of_rows); | |||||
| uint64_t number_of_rows = page->GetEndRowID() - start_row_id; | |||||
| row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -265,7 +265,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::str | |||||
| json construct_json; | json construct_json; | ||||
| for (unsigned int j = 0; j < columns.size(); ++j) { | for (unsigned int j = 0; j < columns.size(); ++j) { | ||||
| // construct json "f1": value | // construct json "f1": value | ||||
| auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"]; | |||||
| auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; | |||||
| // convert the string to base type by schema | // convert the string to base type by schema | ||||
| if (schema[columns[j]]["type"] == "int32") { | if (schema[columns[j]]["type"] == "int32") { | ||||
| @@ -317,7 +317,7 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, | |||||
| MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { | MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { | ||||
| std::map<std::string, uint64_t> index_columns; | std::map<std::string, uint64_t> index_columns; | ||||
| for (auto &field : get_shard_header()->get_fields()) { | |||||
| for (auto &field : GetShardHeader()->GetFields()) { | |||||
| index_columns[field.second] = field.first; | index_columns[field.second] = field.first; | ||||
| } | } | ||||
| if (index_columns.find(category_field) == index_columns.end()) { | if (index_columns.find(category_field) == index_columns.end()) { | ||||
| @@ -400,11 +400,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const | |||||
| } | } | ||||
| const std::shared_ptr<Page> &page = ret.second; | const std::shared_ptr<Page> &page = ret.second; | ||||
| std::string file_name = file_paths_[shard_id]; | std::string file_name = file_paths_[shard_id]; | ||||
| uint64_t page_length = page->get_page_size(); | |||||
| uint64_t page_offset = page_size_ * page->get_page_id() + header_size_; | |||||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id); | |||||
| uint64_t page_length = page->GetPageSize(); | |||||
| uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; | |||||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id); | |||||
| auto status_labels = GetLabels(page->get_page_id(), shard_id, columns); | |||||
| auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); | |||||
| if (status_labels.first != SUCCESS) { | if (status_labels.first != SUCCESS) { | ||||
| return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>()); | return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>()); | ||||
| } | } | ||||
| @@ -426,11 +426,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, | |||||
| } | } | ||||
| const std::shared_ptr<Page> &page = ret.second; | const std::shared_ptr<Page> &page = ret.second; | ||||
| std::string file_name = file_paths_[shard_id]; | std::string file_name = file_paths_[shard_id]; | ||||
| uint64_t page_length = page->get_page_size(); | |||||
| uint64_t page_offset = page_size_ * page->get_page_id() + header_size_; | |||||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->get_page_id(), shard_id, criteria); | |||||
| uint64_t page_length = page->GetPageSize(); | |||||
| uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; | |||||
| std::vector<std::vector<uint64_t>> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); | |||||
| auto status_labels = GetLabels(page->get_page_id(), shard_id, columns, criteria); | |||||
| auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); | |||||
| if (status_labels.first != SUCCESS) { | if (status_labels.first != SUCCESS) { | ||||
| return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>()); | return std::make_tuple(FAILED, "", 0, 0, std::vector<std::vector<uint64_t>>(), std::vector<json>()); | ||||
| } | } | ||||
| @@ -458,7 +458,7 @@ std::vector<std::vector<uint64_t>> ShardReader::GetImageOffset(int page_id, int | |||||
| // whether use index search | // whether use index search | ||||
| if (!criteria.first.empty()) { | if (!criteria.first.empty()) { | ||||
| auto schema = shard_header_->get_schemas()[0]->GetSchema(); | |||||
| auto schema = shard_header_->GetSchemas()[0]->GetSchema(); | |||||
| // not number field should add '' in sql | // not number field should add '' in sql | ||||
| if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { | if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { | ||||
| @@ -497,13 +497,13 @@ void ShardReader::CheckNlp() { | |||||
| return; | return; | ||||
| } | } | ||||
| bool ShardReader::get_nlp_flag() { return nlp_; } | |||||
| bool ShardReader::GetNlpFlag() { return nlp_; } | |||||
| std::pair<ShardType, std::vector<std::string>> ShardReader::get_blob_fields() { | |||||
| std::pair<ShardType, std::vector<std::string>> ShardReader::GetBlobFields() { | |||||
| std::vector<std::string> blob_fields; | std::vector<std::string> blob_fields; | ||||
| for (auto &p : get_shard_header()->get_schemas()) { | |||||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||||
| // assume one schema | // assume one schema | ||||
| const auto &fields = p->get_blob_fields(); | |||||
| const auto &fields = p->GetBlobFields(); | |||||
| blob_fields.assign(fields.begin(), fields.end()); | blob_fields.assign(fields.begin(), fields.end()); | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -516,7 +516,7 @@ void ShardReader::CheckIfColumnInIndex(const std::vector<std::string> &columns) | |||||
| all_in_index_ = false; | all_in_index_ = false; | ||||
| return; | return; | ||||
| } | } | ||||
| for (auto &field : get_shard_header()->get_fields()) { | |||||
| for (auto &field : GetShardHeader()->GetFields()) { | |||||
| column_schema_id_[field.second] = field.first; | column_schema_id_[field.second] = field.first; | ||||
| } | } | ||||
| for (auto &col : columns) { | for (auto &col : columns) { | ||||
| @@ -671,7 +671,7 @@ std::pair<MSRStatus, std::vector<json>> ShardReader::GetLabels(int page_id, int | |||||
| json construct_json; | json construct_json; | ||||
| for (unsigned int j = 0; j < columns.size(); ++j) { | for (unsigned int j = 0; j < columns.size(); ++j) { | ||||
| // construct json "f1": value | // construct json "f1": value | ||||
| auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"]; | |||||
| auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; | |||||
| // convert the string to base type by schema | // convert the string to base type by schema | ||||
| if (schema[columns[j]]["type"] == "int32") { | if (schema[columns[j]]["type"] == "int32") { | ||||
| @@ -719,9 +719,9 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| auto header = std::make_shared<ShardHeader>(sh); | auto header = std::make_shared<ShardHeader>(sh); | ||||
| auto file_paths = header->get_shard_addresses(); | |||||
| auto file_paths = header->GetShardAddresses(); | |||||
| auto shard_count = file_paths.size(); | auto shard_count = file_paths.size(); | ||||
| auto index_fields = header->get_fields(); | |||||
| auto index_fields = header->GetFields(); | |||||
| std::map<std::string, int64_t> map_schema_id_fields; | std::map<std::string, int64_t> map_schema_id_fields; | ||||
| for (auto &field : index_fields) { | for (auto &field : index_fields) { | ||||
| @@ -799,7 +799,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, | |||||
| if (nlp_) { | if (nlp_) { | ||||
| selected_columns_ = selected_columns; | selected_columns_ = selected_columns; | ||||
| } else { | } else { | ||||
| vector<std::string> blob_fields = get_blob_fields().second; | |||||
| vector<std::string> blob_fields = GetBlobFields().second; | |||||
| for (unsigned int i = 0; i < selected_columns.size(); ++i) { | for (unsigned int i = 0; i < selected_columns.size(); ++i) { | ||||
| if (!std::any_of(blob_fields.begin(), blob_fields.end(), | if (!std::any_of(blob_fields.begin(), blob_fields.end(), | ||||
| [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { | [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { | ||||
| @@ -846,7 +846,7 @@ MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consume | |||||
| } | } | ||||
| // should remove blob field from selected_columns when call from python | // should remove blob field from selected_columns when call from python | ||||
| std::vector<std::string> columns(selected_columns); | std::vector<std::string> columns(selected_columns); | ||||
| auto blob_fields = get_blob_fields().second; | |||||
| auto blob_fields = GetBlobFields().second; | |||||
| for (auto &blob_field : blob_fields) { | for (auto &blob_field : blob_fields) { | ||||
| auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); | auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); | ||||
| if (it != selected_columns.end()) { | if (it != selected_columns.end()) { | ||||
| @@ -909,7 +909,7 @@ vector<std::string> ShardReader::GetAllColumns() { | |||||
| vector<std::string> columns; | vector<std::string> columns; | ||||
| if (nlp_) { | if (nlp_) { | ||||
| for (auto &c : selected_columns_) { | for (auto &c : selected_columns_) { | ||||
| for (auto &p : get_shard_header()->get_schemas()) { | |||||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||||
| auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. | auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. | ||||
| for (auto it = schema.begin(); it != schema.end(); ++it) { | for (auto it = schema.begin(); it != schema.end(); ++it) { | ||||
| if (it.key() == c) { | if (it.key() == c) { | ||||
| @@ -943,7 +943,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i | |||||
| CheckIfColumnInIndex(columns); | CheckIfColumnInIndex(columns); | ||||
| auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | ||||
| auto categories = category_op->get_categories(); | |||||
| auto categories = category_op->GetCategories(); | |||||
| int64_t num_elements = category_op->GetNumElements(); | int64_t num_elements = category_op->GetNumElements(); | ||||
| if (num_elements <= 0) { | if (num_elements <= 0) { | ||||
| MS_LOG(ERROR) << "Parameter num_element is not positive"; | MS_LOG(ERROR) << "Parameter num_element is not positive"; | ||||
| @@ -1104,7 +1104,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||||
| } | } | ||||
| // Pick up task from task list | // Pick up task from task list | ||||
| auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]); | |||||
| auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); | |||||
| auto shard_id = std::get<0>(std::get<0>(task)); | auto shard_id = std::get<0>(std::get<0>(task)); | ||||
| auto group_id = std::get<1>(std::get<0>(task)); | auto group_id = std::get<1>(std::get<0>(task)); | ||||
| @@ -1117,7 +1117,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||||
| // Pack image list | // Pack image list | ||||
| std::vector<uint8_t> images(addr[1] - addr[0]); | std::vector<uint8_t> images(addr[1] - addr[0]); | ||||
| auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0]; | |||||
| auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0]; | |||||
| auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); | auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); | ||||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | ||||
| @@ -1139,7 +1139,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ | |||||
| if (selected_columns_.size() == 0) { | if (selected_columns_.size() == 0) { | ||||
| images_with_exact_columns = images; | images_with_exact_columns = images; | ||||
| } else { | } else { | ||||
| auto blob_fields = get_blob_fields(); | |||||
| auto blob_fields = GetBlobFields(); | |||||
| std::vector<uint32_t> ordered_selected_columns_index; | std::vector<uint32_t> ordered_selected_columns_index; | ||||
| uint32_t index = 0; | uint32_t index = 0; | ||||
| @@ -1272,7 +1272,7 @@ MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { | |||||
| } | } | ||||
| // Pick up task from task list | // Pick up task from task list | ||||
| auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]); | |||||
| auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); | |||||
| auto shard_id = std::get<0>(std::get<0>(task)); | auto shard_id = std::get<0>(std::get<0>(task)); | ||||
| auto group_id = std::get<1>(std::get<0>(task)); | auto group_id = std::get<1>(std::get<0>(task)); | ||||
| @@ -28,7 +28,7 @@ using mindspore::MsLogLevel::INFO; | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardSegment::ShardSegment() { set_all_in_index(false); } | |||||
| ShardSegment::ShardSegment() { SetAllInIndex(false); } | |||||
| std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | std::pair<MSRStatus, vector<std::string>> ShardSegment::GetCategoryFields() { | ||||
| // Skip if already populated | // Skip if already populated | ||||
| @@ -211,7 +211,7 @@ std::pair<MSRStatus, std::vector<uint8_t>> ShardSegment::PackImages(int group_id | |||||
| // Pack image list | // Pack image list | ||||
| std::vector<uint8_t> images(offset[1] - offset[0]); | std::vector<uint8_t> images(offset[1] - offset[0]); | ||||
| auto file_offset = header_size_ + page_size_ * (blob_page->get_page_id()) + offset[0]; | |||||
| auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; | |||||
| auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); | auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); | ||||
| if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { | ||||
| MS_LOG(ERROR) << "File seekg failed"; | MS_LOG(ERROR) << "File seekg failed"; | ||||
| @@ -363,21 +363,21 @@ std::pair<MSRStatus, std::vector<std::tuple<std::vector<uint8_t>, pybind11::obje | |||||
| return {SUCCESS, std::move(json_data)}; | return {SUCCESS, std::move(json_data)}; | ||||
| } | } | ||||
| std::pair<ShardType, std::vector<std::string>> ShardSegment::get_blob_fields() { | |||||
| std::pair<ShardType, std::vector<std::string>> ShardSegment::GetBlobFields() { | |||||
| std::vector<std::string> blob_fields; | std::vector<std::string> blob_fields; | ||||
| for (auto &p : get_shard_header()->get_schemas()) { | |||||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||||
| // assume one schema | // assume one schema | ||||
| const auto &fields = p->get_blob_fields(); | |||||
| const auto &fields = p->GetBlobFields(); | |||||
| blob_fields.assign(fields.begin(), fields.end()); | blob_fields.assign(fields.begin(), fields.end()); | ||||
| break; | break; | ||||
| } | } | ||||
| return std::make_pair(get_nlp_flag() ? kNLP : kCV, blob_fields); | |||||
| return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields); | |||||
| } | } | ||||
| std::tuple<std::vector<uint8_t>, json> ShardSegment::GetImageLabel(std::vector<uint8_t> images, json label) { | std::tuple<std::vector<uint8_t>, json> ShardSegment::GetImageLabel(std::vector<uint8_t> images, json label) { | ||||
| if (get_nlp_flag()) { | |||||
| if (GetNlpFlag()) { | |||||
| vector<std::string> columns; | vector<std::string> columns; | ||||
| for (auto &p : get_shard_header()->get_schemas()) { | |||||
| for (auto &p : GetShardHeader()->GetSchemas()) { | |||||
| auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. | auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. | ||||
| auto schema_items = schema.items(); | auto schema_items = schema.items(); | ||||
| using it_type = decltype(schema_items.begin()); | using it_type = decltype(schema_items.begin()); | ||||
| @@ -179,12 +179,12 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| shard_header_ = std::make_shared<ShardHeader>(sh); | shard_header_ = std::make_shared<ShardHeader>(sh); | ||||
| auto paths = shard_header_->get_shard_addresses(); | |||||
| MSRStatus ret = set_header_size(shard_header_->get_header_size()); | |||||
| auto paths = shard_header_->GetShardAddresses(); | |||||
| MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); | |||||
| if (ret == FAILED) { | if (ret == FAILED) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| ret = set_page_size(shard_header_->get_page_size()); | |||||
| ret = SetPageSize(shard_header_->GetPageSize()); | |||||
| if (ret == FAILED) { | if (ret == FAILED) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -229,10 +229,10 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) | |||||
| } | } | ||||
| // set fields in mindrecord when empty | // set fields in mindrecord when empty | ||||
| std::vector<std::pair<uint64_t, std::string>> fields = header_data->get_fields(); | |||||
| std::vector<std::pair<uint64_t, std::string>> fields = header_data->GetFields(); | |||||
| if (fields.empty()) { | if (fields.empty()) { | ||||
| MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields."; | MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields."; | ||||
| std::vector<std::shared_ptr<Schema>> schemas = header_data->get_schemas(); | |||||
| std::vector<std::shared_ptr<Schema>> schemas = header_data->GetSchemas(); | |||||
| for (const auto &schema : schemas) { | for (const auto &schema : schemas) { | ||||
| json jsonSchema = schema->GetSchema()["schema"]; | json jsonSchema = schema->GetSchema()["schema"]; | ||||
| for (const auto &el : jsonSchema.items()) { | for (const auto &el : jsonSchema.items()) { | ||||
| @@ -241,7 +241,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) | |||||
| (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) || | (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) || | ||||
| (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) || | (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) || | ||||
| (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) { | (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) { | ||||
| fields.emplace_back(std::make_pair(schema->get_schema_id(), el.key())); | |||||
| fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key())); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -256,12 +256,12 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data) | |||||
| } | } | ||||
| shard_header_ = header_data; | shard_header_ = header_data; | ||||
| shard_header_->set_header_size(header_size_); | |||||
| shard_header_->set_page_size(page_size_); | |||||
| shard_header_->SetHeaderSize(header_size_); | |||||
| shard_header_->SetPageSize(page_size_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) { | |||||
| MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) { | |||||
| // header_size [16KB, 128MB] | // header_size [16KB, 128MB] | ||||
| if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { | if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { | ||||
| MS_LOG(ERROR) << "Header size should between 16KB and 128MB."; | MS_LOG(ERROR) << "Header size should between 16KB and 128MB."; | ||||
| @@ -276,7 +276,7 @@ MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| MSRStatus ShardWriter::set_page_size(const uint64_t &page_size) { | |||||
| MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) { | |||||
| // PageSize [32KB, 256MB] | // PageSize [32KB, 256MB] | ||||
| if (page_size < kMinPageSize || page_size > kMaxPageSize) { | if (page_size < kMinPageSize || page_size > kMaxPageSize) { | ||||
| MS_LOG(ERROR) << "Page size should between 16KB and 256MB."; | MS_LOG(ERROR) << "Page size should between 16KB and 256MB."; | ||||
| @@ -398,7 +398,7 @@ MSRStatus ShardWriter::CheckData(const std::map<uint64_t, std::vector<json>> &ra | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| json schema = result.first->GetSchema()["schema"]; | json schema = result.first->GetSchema()["schema"]; | ||||
| for (const auto &field : result.first->get_blob_fields()) { | |||||
| for (const auto &field : result.first->GetBlobFields()) { | |||||
| (void)schema.erase(field); | (void)schema.erase(field); | ||||
| } | } | ||||
| std::vector<json> sub_raw_data = rawdata_iter->second; | std::vector<json> sub_raw_data = rawdata_iter->second; | ||||
| @@ -456,7 +456,7 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t, | |||||
| MS_LOG(DEBUG) << "Schema count is " << schema_count_; | MS_LOG(DEBUG) << "Schema count is " << schema_count_; | ||||
| // Determine if the number of schemas is the same | // Determine if the number of schemas is the same | ||||
| if (shard_header_->get_schemas().size() != schema_count_) { | |||||
| if (shard_header_->GetSchemas().size() != schema_count_) { | |||||
| MS_LOG(ERROR) << "Data size is not equal with the schema size"; | MS_LOG(ERROR) << "Data size is not equal with the schema size"; | ||||
| return failed; | return failed; | ||||
| } | } | ||||
| @@ -475,9 +475,9 @@ std::tuple<MSRStatus, int, int> ShardWriter::ValidateRawData(std::map<uint64_t, | |||||
| } | } | ||||
| (void)schema_ids.insert(rawdata_iter->first); | (void)schema_ids.insert(rawdata_iter->first); | ||||
| } | } | ||||
| const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->get_schemas(); | |||||
| const std::vector<std::shared_ptr<Schema>> &schemas = shard_header_->GetSchemas(); | |||||
| if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr<Schema> &schema) { | if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr<Schema> &schema) { | ||||
| return schema_ids.find(schema->get_schema_id()) == schema_ids.end(); | |||||
| return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); | |||||
| })) { | })) { | ||||
| // There is not enough data which is not matching the number of schema | // There is not enough data which is not matching the number of schema | ||||
| MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; | MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; | ||||
| @@ -810,10 +810,10 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector | |||||
| std::vector<std::pair<int, int>> &rows_in_group, | std::vector<std::pair<int, int>> &rows_in_group, | ||||
| const std::shared_ptr<Page> &last_raw_page, | const std::shared_ptr<Page> &last_raw_page, | ||||
| const std::shared_ptr<Page> &last_blob_page) { | const std::shared_ptr<Page> &last_blob_page) { | ||||
| auto n_byte_blob = last_blob_page ? last_blob_page->get_page_size() : 0; | |||||
| auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0; | |||||
| auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0; | |||||
| auto last_raw_offset = last_raw_page ? last_raw_page->get_last_row_group_id().second : 0; | |||||
| auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; | |||||
| auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0; | |||||
| auto n_byte_raw = last_raw_page_size - last_raw_offset; | auto n_byte_raw = last_raw_page_size - last_raw_offset; | ||||
| int page_start_row = start_row; | int page_start_row = start_row; | ||||
| @@ -849,8 +849,8 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std | |||||
| if (blob_row.first == blob_row.second) return SUCCESS; | if (blob_row.first == blob_row.second) return SUCCESS; | ||||
| // Write disk | // Write disk | ||||
| auto page_id = last_blob_page->get_page_id(); | |||||
| auto bytes_page = last_blob_page->get_page_size(); | |||||
| auto page_id = last_blob_page->GetPageID(); | |||||
| auto bytes_page = last_blob_page->GetPageSize(); | |||||
| auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); | auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); | ||||
| if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { | if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { | ||||
| MS_LOG(ERROR) << "File seekp failed"; | MS_LOG(ERROR) << "File seekp failed"; | ||||
| @@ -862,9 +862,9 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vector<std | |||||
| // Update last blob page | // Update last blob page | ||||
| bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); | bytes_page += std::accumulate(blob_data_size_.begin() + blob_row.first, blob_data_size_.begin() + blob_row.second, 0); | ||||
| last_blob_page->set_page_size(bytes_page); | |||||
| uint64_t end_row = last_blob_page->get_end_row_id() + blob_row.second - blob_row.first; | |||||
| last_blob_page->set_end_row_id(end_row); | |||||
| last_blob_page->SetPageSize(bytes_page); | |||||
| uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first; | |||||
| last_blob_page->SetEndRowID(end_row); | |||||
| (void)shard_header_->SetPage(last_blob_page); | (void)shard_header_->SetPage(last_blob_page); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -873,8 +873,8 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector<std::v | |||||
| const std::vector<std::pair<int, int>> &rows_in_group, | const std::vector<std::pair<int, int>> &rows_in_group, | ||||
| const std::shared_ptr<Page> &last_blob_page) { | const std::shared_ptr<Page> &last_blob_page) { | ||||
| auto page_id = shard_header_->GetLastPageId(shard_id); | auto page_id = shard_header_->GetLastPageId(shard_id); | ||||
| auto page_type_id = last_blob_page ? last_blob_page->get_page_type_id() : -1; | |||||
| auto current_row = last_blob_page ? last_blob_page->get_end_row_id() : 0; | |||||
| auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; | |||||
| auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0; | |||||
| // index(0) indicate appendBlobPage | // index(0) indicate appendBlobPage | ||||
| for (uint32_t i = 1; i < rows_in_group.size(); ++i) { | for (uint32_t i = 1; i < rows_in_group.size(); ++i) { | ||||
| auto blob_row = rows_in_group[i]; | auto blob_row = rows_in_group[i]; | ||||
| @@ -905,15 +905,15 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std:: | |||||
| std::shared_ptr<Page> &last_raw_page) { | std::shared_ptr<Page> &last_raw_page) { | ||||
| auto blob_row = rows_in_group[0]; | auto blob_row = rows_in_group[0]; | ||||
| if (blob_row.first == blob_row.second) return SUCCESS; | if (blob_row.first == blob_row.second) return SUCCESS; | ||||
| auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0; | |||||
| auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; | |||||
| if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + | if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + | ||||
| last_raw_page_size <= | last_raw_page_size <= | ||||
| page_size_) { | page_size_) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto page_id = shard_header_->GetLastPageId(shard_id); | auto page_id = shard_header_->GetLastPageId(shard_id); | ||||
| auto last_row_group_id_offset = last_raw_page->get_last_row_group_id().second; | |||||
| auto last_raw_page_id = last_raw_page->get_page_id(); | |||||
| auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; | |||||
| auto last_raw_page_id = last_raw_page->GetPageID(); | |||||
| auto shift_size = last_raw_page_size - last_row_group_id_offset; | auto shift_size = last_raw_page_size - last_row_group_id_offset; | ||||
| std::vector<uint8_t> buf(shift_size); | std::vector<uint8_t> buf(shift_size); | ||||
| @@ -956,10 +956,10 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std:: | |||||
| (void)shard_header_->SetPage(last_raw_page); | (void)shard_header_->SetPage(last_raw_page); | ||||
| // Refresh page info in header | // Refresh page info in header | ||||
| int row_group_id = last_raw_page->get_last_row_group_id().first + 1; | |||||
| int row_group_id = last_raw_page->GetLastRowGroupID().first + 1; | |||||
| std::vector<std::pair<int, uint64_t>> row_group_ids; | std::vector<std::pair<int, uint64_t>> row_group_ids; | ||||
| row_group_ids.emplace_back(row_group_id, 0); | row_group_ids.emplace_back(row_group_id, 0); | ||||
| int page_type_id = last_raw_page->get_page_id(); | |||||
| int page_type_id = last_raw_page->GetPageID(); | |||||
| auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); | auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); | ||||
| (void)shard_header_->AddPage(std::make_shared<Page>(page)); | (void)shard_header_->AddPage(std::make_shared<Page>(page)); | ||||
| @@ -971,7 +971,7 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector<std:: | |||||
| MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | ||||
| std::shared_ptr<Page> &last_raw_page, | std::shared_ptr<Page> &last_raw_page, | ||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data) { | const std::vector<std::vector<uint8_t>> &bin_raw_data) { | ||||
| int last_row_group_id = last_raw_page ? last_raw_page->get_last_row_group_id().first : -1; | |||||
| int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; | |||||
| for (uint32_t i = 0; i < rows_in_group.size(); ++i) { | for (uint32_t i = 0; i < rows_in_group.size(); ++i) { | ||||
| const auto &blob_row = rows_in_group[i]; | const auto &blob_row = rows_in_group[i]; | ||||
| if (blob_row.first == blob_row.second) continue; | if (blob_row.first == blob_row.second) continue; | ||||
| @@ -979,7 +979,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std:: | |||||
| std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); | std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0); | ||||
| if (!last_raw_page) { | if (!last_raw_page) { | ||||
| EmptyRawPage(shard_id, last_raw_page); | EmptyRawPage(shard_id, last_raw_page); | ||||
| } else if (last_raw_page->get_page_size() + raw_size > page_size_) { | |||||
| } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { | |||||
| (void)shard_header_->SetPage(last_raw_page); | (void)shard_header_->SetPage(last_raw_page); | ||||
| EmptyRawPage(shard_id, last_raw_page); | EmptyRawPage(shard_id, last_raw_page); | ||||
| } | } | ||||
| @@ -994,7 +994,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector<std:: | |||||
| void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) { | void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_raw_page) { | ||||
| auto row_group_ids = std::vector<std::pair<int, uint64_t>>(); | auto row_group_ids = std::vector<std::pair<int, uint64_t>>(); | ||||
| auto page_id = shard_header_->GetLastPageId(shard_id); | auto page_id = shard_header_->GetLastPageId(shard_id); | ||||
| auto page_type_id = last_raw_page ? last_raw_page->get_page_id() : -1; | |||||
| auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; | |||||
| auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); | auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); | ||||
| (void)shard_header_->AddPage(std::make_shared<Page>(page)); | (void)shard_header_->AddPage(std::make_shared<Page>(page)); | ||||
| SetLastRawPage(shard_id, last_raw_page); | SetLastRawPage(shard_id, last_raw_page); | ||||
| @@ -1003,9 +1003,9 @@ void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr<Page> &last_ | |||||
| MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std::pair<int, int>> &rows_in_group, | ||||
| const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page, | const int &chunk_id, int &last_row_group_id, std::shared_ptr<Page> last_raw_page, | ||||
| const std::vector<std::vector<uint8_t>> &bin_raw_data) { | const std::vector<std::vector<uint8_t>> &bin_raw_data) { | ||||
| std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->get_row_group_ids(); | |||||
| auto last_raw_page_id = last_raw_page->get_page_id(); | |||||
| auto n_bytes = last_raw_page->get_page_size(); | |||||
| std::vector<std::pair<int, uint64_t>> row_group_ids = last_raw_page->GetRowGroupIds(); | |||||
| auto last_raw_page_id = last_raw_page->GetPageID(); | |||||
| auto n_bytes = last_raw_page->GetPageSize(); | |||||
| // previous raw data page | // previous raw data page | ||||
| auto &io_seekp = | auto &io_seekp = | ||||
| @@ -1022,8 +1022,8 @@ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector<std: | |||||
| (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); | (void)FlushRawChunk(file_streams_[shard_id], rows_in_group, chunk_id, bin_raw_data); | ||||
| // Update previous raw data page | // Update previous raw data page | ||||
| last_raw_page->set_page_size(n_bytes); | |||||
| last_raw_page->set_row_group_ids(row_group_ids); | |||||
| last_raw_page->SetPageSize(n_bytes); | |||||
| last_raw_page->SetRowGroupIds(row_group_ids); | |||||
| (void)shard_header_->SetPage(last_raw_page); | (void)shard_header_->SetPage(last_raw_page); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem | |||||
| num_categories_(num_categories), | num_categories_(num_categories), | ||||
| replacement_(replacement) {} | replacement_(replacement) {} | ||||
| MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } | |||||
| MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } | |||||
| int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | ||||
| if (dataset_size == 0) return dataset_size; | if (dataset_size == 0) return dataset_size; | ||||
| @@ -343,7 +343,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() { | |||||
| std::string ShardHeader::SerializeIndexFields() { | std::string ShardHeader::SerializeIndexFields() { | ||||
| json j; | json j; | ||||
| auto fields = index_->get_fields(); | |||||
| auto fields = index_->GetFields(); | |||||
| for (const auto &field : fields) { | for (const auto &field : fields) { | ||||
| j.push_back({{"schema_id", field.first}, {"index_field", field.second}}); | j.push_back({{"schema_id", field.first}, {"index_field", field.second}}); | ||||
| } | } | ||||
| @@ -365,7 +365,7 @@ std::vector<std::string> ShardHeader::SerializePage() { | |||||
| std::string ShardHeader::SerializeStatistics() { | std::string ShardHeader::SerializeStatistics() { | ||||
| json j; | json j; | ||||
| for (const auto &stats : statistics_) { | for (const auto &stats : statistics_) { | ||||
| j.emplace_back(stats->get_statistics()); | |||||
| j.emplace_back(stats->GetStatistics()); | |||||
| } | } | ||||
| return j.dump(); | return j.dump(); | ||||
| } | } | ||||
| @@ -398,8 +398,8 @@ MSRStatus ShardHeader::SetPage(const std::shared_ptr<Page> &new_page) { | |||||
| if (new_page == nullptr) { | if (new_page == nullptr) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| int shard_id = new_page->get_shard_id(); | |||||
| int page_id = new_page->get_page_id(); | |||||
| int shard_id = new_page->GetShardID(); | |||||
| int page_id = new_page->GetPageID(); | |||||
| if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | if (shard_id < static_cast<int>(pages_.size()) && page_id < static_cast<int>(pages_[shard_id].size())) { | ||||
| pages_[shard_id][page_id] = new_page; | pages_[shard_id][page_id] = new_page; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -412,8 +412,8 @@ MSRStatus ShardHeader::AddPage(const std::shared_ptr<Page> &new_page) { | |||||
| if (new_page == nullptr) { | if (new_page == nullptr) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| int shard_id = new_page->get_shard_id(); | |||||
| int page_id = new_page->get_page_id(); | |||||
| int shard_id = new_page->GetShardID(); | |||||
| int page_id = new_page->GetPageID(); | |||||
| if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) { | if (shard_id < static_cast<int>(pages_.size()) && page_id == static_cast<int>(pages_[shard_id].size())) { | ||||
| pages_[shard_id].push_back(new_page); | pages_[shard_id].push_back(new_page); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -435,8 +435,8 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag | |||||
| } | } | ||||
| int last_page_id = -1; | int last_page_id = -1; | ||||
| for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | ||||
| if (pages_[shard_id][i - 1]->get_page_type() == page_type) { | |||||
| last_page_id = pages_[shard_id][i - 1]->get_page_id(); | |||||
| if (pages_[shard_id][i - 1]->GetPageType() == page_type) { | |||||
| last_page_id = pages_[shard_id][i - 1]->GetPageID(); | |||||
| return last_page_id; | return last_page_id; | ||||
| } | } | ||||
| } | } | ||||
| @@ -451,7 +451,7 @@ const std::pair<MSRStatus, std::shared_ptr<Page>> ShardHeader::GetPageByGroupId( | |||||
| } | } | ||||
| for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { | ||||
| auto page = pages_[shard_id][i - 1]; | auto page = pages_[shard_id][i - 1]; | ||||
| if (page->get_page_type() == kPageTypeBlob && page->get_page_type_id() == group_id) { | |||||
| if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { | |||||
| return {SUCCESS, page}; | return {SUCCESS, page}; | ||||
| } | } | ||||
| } | } | ||||
| @@ -470,10 +470,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| int64_t schema_id = schema->get_schema_id(); | |||||
| int64_t schema_id = schema->GetSchemaID(); | |||||
| if (schema_id == -1) { | if (schema_id == -1) { | ||||
| schema_id = schema_.size(); | schema_id = schema_.size(); | ||||
| schema->set_schema_id(schema_id); | |||||
| schema->SetSchemaID(schema_id); | |||||
| } | } | ||||
| schema_.push_back(schema); | schema_.push_back(schema); | ||||
| return schema_id; | return schema_id; | ||||
| @@ -481,10 +481,10 @@ int ShardHeader::AddSchema(std::shared_ptr<Schema> schema) { | |||||
| void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) { | void ShardHeader::AddStatistic(std::shared_ptr<Statistics> statistic) { | ||||
| if (statistic) { | if (statistic) { | ||||
| int64_t statistics_id = statistic->get_statistics_id(); | |||||
| int64_t statistics_id = statistic->GetStatisticsID(); | |||||
| if (statistics_id == -1) { | if (statistics_id == -1) { | ||||
| statistics_id = statistics_.size(); | statistics_id = statistics_.size(); | ||||
| statistic->set_statistics_id(statistics_id); | |||||
| statistic->SetStatisticsID(statistics_id); | |||||
| } | } | ||||
| statistics_.push_back(statistic); | statistics_.push_back(statistic); | ||||
| } | } | ||||
| @@ -527,13 +527,13 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (get_schemas().empty()) { | |||||
| if (GetSchemas().empty()) { | |||||
| MS_LOG(ERROR) << "No schema is set"; | MS_LOG(ERROR) << "No schema is set"; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| for (const auto &schemaPtr : schema_) { | for (const auto &schemaPtr : schema_) { | ||||
| auto result = GetSchemaByID(schemaPtr->get_schema_id()); | |||||
| auto result = GetSchemaByID(schemaPtr->GetSchemaID()); | |||||
| if (result.second != SUCCESS) { | if (result.second != SUCCESS) { | ||||
| MS_LOG(ERROR) << "Could not get schema by id."; | MS_LOG(ERROR) << "Could not get schema by id."; | ||||
| return FAILED; | return FAILED; | ||||
| @@ -548,7 +548,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||||
| // checkout and add fields for each schema | // checkout and add fields for each schema | ||||
| std::set<std::string> field_set; | std::set<std::string> field_set; | ||||
| for (const auto &item : index->get_fields()) { | |||||
| for (const auto &item : index->GetFields()) { | |||||
| field_set.insert(item.second); | field_set.insert(item.second); | ||||
| } | } | ||||
| for (const auto &field : fields) { | for (const auto &field : fields) { | ||||
| @@ -564,7 +564,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||||
| field_set.insert(field); | field_set.insert(field); | ||||
| // add field into index | // add field into index | ||||
| index.get()->AddIndexField(schemaPtr->get_schema_id(), field); | |||||
| index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); | |||||
| } | } | ||||
| } | } | ||||
| @@ -575,12 +575,12 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector<std::string> &fields) { | |||||
| MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) { | MSRStatus ShardHeader::GetAllSchemaID(std::set<uint64_t> &bucket_count) { | ||||
| // get all schema id | // get all schema id | ||||
| for (const auto &schema : schema_) { | for (const auto &schema : schema_) { | ||||
| auto bucket_it = bucket_count.find(schema->get_schema_id()); | |||||
| auto bucket_it = bucket_count.find(schema->GetSchemaID()); | |||||
| if (bucket_it != bucket_count.end()) { | if (bucket_it != bucket_count.end()) { | ||||
| MS_LOG(ERROR) << "Schema duplication"; | MS_LOG(ERROR) << "Schema duplication"; | ||||
| return FAILED; | return FAILED; | ||||
| } else { | } else { | ||||
| bucket_count.insert(schema->get_schema_id()); | |||||
| bucket_count.insert(schema->GetSchemaID()); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -603,7 +603,7 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin | |||||
| // check and add fields for each schema | // check and add fields for each schema | ||||
| std::set<std::pair<uint64_t, std::string>> field_set; | std::set<std::pair<uint64_t, std::string>> field_set; | ||||
| for (const auto &item : index->get_fields()) { | |||||
| for (const auto &item : index->GetFields()) { | |||||
| field_set.insert(item); | field_set.insert(item); | ||||
| } | } | ||||
| for (const auto &field : fields) { | for (const auto &field : fields) { | ||||
| @@ -646,20 +646,20 @@ MSRStatus ShardHeader::AddIndexFields(std::vector<std::pair<uint64_t, std::strin | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| std::string ShardHeader::get_shard_address_by_id(int64_t shard_id) { | |||||
| std::string ShardHeader::GetShardAddressByID(int64_t shard_id) { | |||||
| if (shard_id >= shard_addresses_.size()) { | if (shard_id >= shard_addresses_.size()) { | ||||
| return ""; | return ""; | ||||
| } | } | ||||
| return shard_addresses_.at(shard_id); | return shard_addresses_.at(shard_id); | ||||
| } | } | ||||
| std::vector<std::shared_ptr<Schema>> ShardHeader::get_schemas() { return schema_; } | |||||
| std::vector<std::shared_ptr<Schema>> ShardHeader::GetSchemas() { return schema_; } | |||||
| std::vector<std::shared_ptr<Statistics>> ShardHeader::get_statistics() { return statistics_; } | |||||
| std::vector<std::shared_ptr<Statistics>> ShardHeader::GetStatistics() { return statistics_; } | |||||
| std::vector<std::pair<uint64_t, std::string>> ShardHeader::get_fields() { return index_->get_fields(); } | |||||
| std::vector<std::pair<uint64_t, std::string>> ShardHeader::GetFields() { return index_->GetFields(); } | |||||
| std::shared_ptr<Index> ShardHeader::get_index() { return index_; } | |||||
| std::shared_ptr<Index> ShardHeader::GetIndex() { return index_; } | |||||
| std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { | std::pair<std::shared_ptr<Schema>, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { | ||||
| int64_t schemaSize = schema_.size(); | int64_t schemaSize = schema_.size(); | ||||
| @@ -28,6 +28,6 @@ void Index::AddIndexField(const int64_t &schemaId, const std::string &field) { | |||||
| } | } | ||||
| // Get attribute list | // Get attribute list | ||||
| std::vector<std::pair<uint64_t, std::string>> Index::get_fields() { return fields_; } | |||||
| std::vector<std::pair<uint64_t, std::string>> Index::GetFields() { return fields_; } | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,7 +34,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem | |||||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | ||||
| } | } | ||||
| MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) { | |||||
| MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { | |||||
| if (shuffle_ == true) { | if (shuffle_ == true) { | ||||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | if (SUCCESS != (*shuffle_op_)(tasks)) { | ||||
| return FAILED; | return FAILED; | ||||
| @@ -74,14 +74,14 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| const std::pair<int, int> ShardSample::get_partitions() const { | |||||
| const std::pair<int, int> ShardSample::GetPartitions() const { | |||||
| if (numerator_ == 1 && denominator_ > 1) { | if (numerator_ == 1 && denominator_ > 1) { | ||||
| return std::pair<int, int>(denominator_, partition_id_); | return std::pair<int, int>(denominator_, partition_id_); | ||||
| } | } | ||||
| return std::pair<int, int>(-1, -1); | return std::pair<int, int>(-1, -1); | ||||
| } | } | ||||
| MSRStatus ShardSample::execute(ShardTask &tasks) { | |||||
| MSRStatus ShardSample::Execute(ShardTask &tasks) { | |||||
| int no_of_categories = static_cast<int>(tasks.categories); | int no_of_categories = static_cast<int>(tasks.categories); | ||||
| int total_no = static_cast<int>(tasks.Size()); | int total_no = static_cast<int>(tasks.Size()); | ||||
| @@ -114,11 +114,11 @@ MSRStatus ShardSample::execute(ShardTask &tasks) { | |||||
| if (sampler_type_ == kSubsetRandomSampler) { | if (sampler_type_ == kSubsetRandomSampler) { | ||||
| for (int i = 0; i < indices_.size(); ++i) { | for (int i = 0; i < indices_.size(); ++i) { | ||||
| int index = ((indices_[i] % total_no) + total_no) % total_no; | int index = ((indices_[i] % total_no) + total_no) % total_no; | ||||
| new_tasks.InsertTask(tasks.get_task_by_id(index)); // different mod result between c and python | |||||
| new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python | |||||
| } | } | ||||
| } else { | } else { | ||||
| for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | ||||
| new_tasks.InsertTask(tasks.get_task_by_id(i % total_no)); // rounding up. if overflow, go back to start | |||||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start | |||||
| } | } | ||||
| } | } | ||||
| std::swap(tasks, new_tasks); | std::swap(tasks, new_tasks); | ||||
| @@ -129,14 +129,14 @@ MSRStatus ShardSample::execute(ShardTask &tasks) { | |||||
| } | } | ||||
| total_no = static_cast<int>(tasks.permutation_.size()); | total_no = static_cast<int>(tasks.permutation_.size()); | ||||
| for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | ||||
| new_tasks.InsertTask(tasks.get_task_by_id(tasks.permutation_[i % total_no])); | |||||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||||
| } | } | ||||
| std::swap(tasks, new_tasks); | std::swap(tasks, new_tasks); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| MSRStatus ShardSample::suf_execute(ShardTask &tasks) { | |||||
| MSRStatus ShardSample::SufExecute(ShardTask &tasks) { | |||||
| if (sampler_type_ == kSubsetRandomSampler) { | if (sampler_type_ == kSubsetRandomSampler) { | ||||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | if (SUCCESS != (*shuffle_op_)(tasks)) { | ||||
| return FAILED; | return FAILED; | ||||
| @@ -44,7 +44,7 @@ std::shared_ptr<Schema> Schema::Build(std::string desc, pybind11::handle schema) | |||||
| return Build(std::move(desc), schema_json); | return Build(std::move(desc), schema_json); | ||||
| } | } | ||||
| std::string Schema::get_desc() const { return desc_; } | |||||
| std::string Schema::GetDesc() const { return desc_; } | |||||
| json Schema::GetSchema() const { | json Schema::GetSchema() const { | ||||
| json str_schema; | json str_schema; | ||||
| @@ -60,11 +60,11 @@ pybind11::object Schema::GetSchemaForPython() const { | |||||
| return schema_py; | return schema_py; | ||||
| } | } | ||||
| void Schema::set_schema_id(int64_t id) { schema_id_ = id; } | |||||
| void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } | |||||
| int64_t Schema::get_schema_id() const { return schema_id_; } | |||||
| int64_t Schema::GetSchemaID() const { return schema_id_; } | |||||
| std::vector<std::string> Schema::get_blob_fields() const { return blob_fields_; } | |||||
| std::vector<std::string> Schema::GetBlobFields() const { return blob_fields_; } | |||||
| std::vector<std::string> Schema::PopulateBlobFields(json schema) { | std::vector<std::string> Schema::PopulateBlobFields(json schema) { | ||||
| std::vector<std::string> blob_fields; | std::vector<std::string> blob_fields; | ||||
| @@ -155,7 +155,7 @@ bool Schema::Validate(json schema) { | |||||
| } | } | ||||
| bool Schema::operator==(const mindrecord::Schema &b) const { | bool Schema::operator==(const mindrecord::Schema &b) const { | ||||
| if (this->get_desc() != b.get_desc() || this->GetSchema() != b.GetSchema()) { | |||||
| if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -23,7 +23,7 @@ namespace mindrecord { | |||||
| ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) | ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) | ||||
| : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} | : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} | ||||
| MSRStatus ShardShuffle::execute(ShardTask &tasks) { | |||||
| MSRStatus ShardShuffle::Execute(ShardTask &tasks) { | |||||
| if (tasks.categories < 1) { | if (tasks.categories < 1) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -48,9 +48,9 @@ std::shared_ptr<Statistics> Statistics::Build(std::string desc, pybind11::handle | |||||
| return std::make_shared<Statistics>(object_statistics); | return std::make_shared<Statistics>(object_statistics); | ||||
| } | } | ||||
| std::string Statistics::get_desc() const { return desc_; } | |||||
| std::string Statistics::GetDesc() const { return desc_; } | |||||
| json Statistics::get_statistics() const { | |||||
| json Statistics::GetStatistics() const { | |||||
| json str_statistics; | json str_statistics; | ||||
| str_statistics["desc"] = desc_; | str_statistics["desc"] = desc_; | ||||
| str_statistics["statistics"] = statistics_; | str_statistics["statistics"] = statistics_; | ||||
| @@ -58,13 +58,13 @@ json Statistics::get_statistics() const { | |||||
| } | } | ||||
| pybind11::object Statistics::GetStatisticsForPython() const { | pybind11::object Statistics::GetStatisticsForPython() const { | ||||
| json str_statistics = Statistics::get_statistics(); | |||||
| json str_statistics = Statistics::GetStatistics(); | |||||
| return nlohmann::detail::FromJsonImpl(str_statistics); | return nlohmann::detail::FromJsonImpl(str_statistics); | ||||
| } | } | ||||
| void Statistics::set_statistics_id(int64_t id) { statistics_id_ = id; } | |||||
| void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } | |||||
| int64_t Statistics::get_statistics_id() const { return statistics_id_; } | |||||
| int64_t Statistics::GetStatisticsID() const { return statistics_id_; } | |||||
| bool Statistics::Validate(const json &statistics) { | bool Statistics::Validate(const json &statistics) { | ||||
| if (statistics.size() != kInt1) { | if (statistics.size() != kInt1) { | ||||
| @@ -103,7 +103,7 @@ bool Statistics::LevelRecursive(json level) { | |||||
| } | } | ||||
| bool Statistics::operator==(const Statistics &b) const { | bool Statistics::operator==(const Statistics &b) const { | ||||
| if (this->get_statistics() != b.get_statistics()) { | |||||
| if (this->GetStatistics() != b.GetStatistics()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -59,12 +59,12 @@ uint32_t ShardTask::SizeOfRows() const { | |||||
| return nRows; | return nRows; | ||||
| } | } | ||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_task_by_id(size_t id) { | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetTaskByID(size_t id) { | |||||
| MS_ASSERT(id < task_list_.size()); | MS_ASSERT(id < task_list_.size()); | ||||
| return task_list_[id]; | return task_list_[id]; | ||||
| } | } | ||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_random_task() { | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::GetRandomTask() { | |||||
| std::random_device rd; | std::random_device rd; | ||||
| std::mt19937 gen(rd()); | std::mt19937 gen(rd()); | ||||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | ||||
| @@ -82,7 +82,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||||
| } | } | ||||
| for (uint32_t task_no = 0; task_no < minTasks; task_no++) { | for (uint32_t task_no = 0; task_no < minTasks; task_no++) { | ||||
| for (uint32_t i = 0; i < total_categories; i++) { | for (uint32_t i = 0; i < total_categories; i++) { | ||||
| res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no)))); | |||||
| res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no)))); | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -95,7 +95,7 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||||
| } | } | ||||
| for (uint32_t i = 0; i < total_categories; i++) { | for (uint32_t i = 0; i < total_categories; i++) { | ||||
| for (uint32_t j = 0; j < maxTasks; j++) { | for (uint32_t j = 0; j < maxTasks; j++) { | ||||
| res.InsertTask(category_tasks[i].get_random_task()); | |||||
| res.InsertTask(category_tasks[i].GetRandomTask()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -52,7 +52,7 @@ TEST_F(TestShard, TestShardSchemaPart) { | |||||
| std::shared_ptr<Schema> schema = Schema::Build(desc, j); | std::shared_ptr<Schema> schema = Schema::Build(desc, j); | ||||
| ASSERT_TRUE(schema != nullptr); | ASSERT_TRUE(schema != nullptr); | ||||
| MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " << | |||||
| MS_LOG(INFO) << "schema description: " << schema->GetDesc() << ", schema: " << | |||||
| common::SafeCStr(schema->GetSchema().dump()); | common::SafeCStr(schema->GetSchema().dump()); | ||||
| for (int i = 1; i <= 4; i++) { | for (int i = 1; i <= 4; i++) { | ||||
| string filename = std::string("./imagenet.shard0") + std::to_string(i); | string filename = std::string("./imagenet.shard0") + std::to_string(i); | ||||
| @@ -71,8 +71,8 @@ TEST_F(TestShard, TestStatisticPart) { | |||||
| nlohmann::json statistic_json = json::parse(kStatistics[2]); | nlohmann::json statistic_json = json::parse(kStatistics[2]); | ||||
| std::shared_ptr<Statistics> statistics = Statistics::Build(desc, statistic_json); | std::shared_ptr<Statistics> statistics = Statistics::Build(desc, statistic_json); | ||||
| ASSERT_TRUE(statistics != nullptr); | ASSERT_TRUE(statistics != nullptr); | ||||
| MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc(); | |||||
| MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump(); | |||||
| MS_LOG(INFO) << "test get_desc(), result: " << statistics->GetDesc(); | |||||
| MS_LOG(INFO) << "test get_statistics, result: " << statistics->GetStatistics().dump(); | |||||
| std::string desc2 = "axis"; | std::string desc2 = "axis"; | ||||
| nlohmann::json statistic_json2 = R"({})"; | nlohmann::json statistic_json2 = R"({})"; | ||||
| @@ -111,13 +111,13 @@ TEST_F(TestShard, TestShardHeaderPart) { | |||||
| ASSERT_EQ(res, 0); | ASSERT_EQ(res, 0); | ||||
| header_data.AddStatistic(statistics1); | header_data.AddStatistic(statistics1); | ||||
| std::vector<Schema> re_schemas; | std::vector<Schema> re_schemas; | ||||
| for (auto &schema_ptr : header_data.get_schemas()) { | |||||
| for (auto &schema_ptr : header_data.GetSchemas()) { | |||||
| re_schemas.push_back(*schema_ptr); | re_schemas.push_back(*schema_ptr); | ||||
| } | } | ||||
| ASSERT_EQ(re_schemas, validate_schema); | ASSERT_EQ(re_schemas, validate_schema); | ||||
| std::vector<Statistics> re_statistics; | std::vector<Statistics> re_statistics; | ||||
| for (auto &statistic : header_data.get_statistics()) { | |||||
| for (auto &statistic : header_data.GetStatistics()) { | |||||
| re_statistics.push_back(*statistic); | re_statistics.push_back(*statistic); | ||||
| } | } | ||||
| ASSERT_EQ(re_statistics, validate_statistics); | ASSERT_EQ(re_statistics, validate_statistics); | ||||
| @@ -129,7 +129,7 @@ TEST_F(TestShard, TestShardHeaderPart) { | |||||
| std::pair<uint64_t, std::string> pair1(0, "name"); | std::pair<uint64_t, std::string> pair1(0, "name"); | ||||
| fields.push_back(pair1); | fields.push_back(pair1); | ||||
| ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS); | ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS); | ||||
| std::vector<std::pair<uint64_t, std::string>> resFields = header_data.get_fields(); | |||||
| std::vector<std::pair<uint64_t, std::string>> resFields = header_data.GetFields(); | |||||
| ASSERT_EQ(resFields, fields); | ASSERT_EQ(resFields, fields); | ||||
| } | } | ||||
| @@ -70,7 +70,7 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||||
| int schema_id1 = header_data.AddSchema(schema1); | int schema_id1 = header_data.AddSchema(schema1); | ||||
| int schema_id2 = header_data.AddSchema(schema2); | int schema_id2 = header_data.AddSchema(schema2); | ||||
| ASSERT_EQ(schema_id2, -1); | ASSERT_EQ(schema_id2, -1); | ||||
| ASSERT_EQ(header_data.get_schemas().size(), 1); | |||||
| ASSERT_EQ(header_data.GetSchemas().size(), 1); | |||||
| // check out fields | // check out fields | ||||
| std::vector<std::pair<uint64_t, std::string>> fields; | std::vector<std::pair<uint64_t, std::string>> fields; | ||||
| @@ -81,35 +81,35 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||||
| fields.push_back(index_field2); | fields.push_back(index_field2); | ||||
| MSRStatus res = header_data.AddIndexFields(fields); | MSRStatus res = header_data.AddIndexFields(fields); | ||||
| ASSERT_EQ(res, SUCCESS); | ASSERT_EQ(res, SUCCESS); | ||||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field3(schema_id1, "name"); | std::pair<uint64_t, std::string> index_field3(schema_id1, "name"); | ||||
| fields.push_back(index_field3); | fields.push_back(index_field3); | ||||
| res = header_data.AddIndexFields(fields); | res = header_data.AddIndexFields(fields); | ||||
| ASSERT_EQ(res, FAILED); | ASSERT_EQ(res, FAILED); | ||||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field4(schema_id1, "names"); | std::pair<uint64_t, std::string> index_field4(schema_id1, "names"); | ||||
| fields.push_back(index_field4); | fields.push_back(index_field4); | ||||
| res = header_data.AddIndexFields(fields); | res = header_data.AddIndexFields(fields); | ||||
| ASSERT_EQ(res, FAILED); | ASSERT_EQ(res, FAILED); | ||||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name"); | std::pair<uint64_t, std::string> index_field5(schema_id1 + 1, "name"); | ||||
| fields.push_back(index_field5); | fields.push_back(index_field5); | ||||
| res = header_data.AddIndexFields(fields); | res = header_data.AddIndexFields(fields); | ||||
| ASSERT_EQ(res, FAILED); | ASSERT_EQ(res, FAILED); | ||||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||||
| fields.clear(); | fields.clear(); | ||||
| std::pair<uint64_t, std::string> index_field6(schema_id1, "label"); | std::pair<uint64_t, std::string> index_field6(schema_id1, "label"); | ||||
| fields.push_back(index_field6); | fields.push_back(index_field6); | ||||
| res = header_data.AddIndexFields(fields); | res = header_data.AddIndexFields(fields); | ||||
| ASSERT_EQ(res, FAILED); | ASSERT_EQ(res, FAILED); | ||||
| ASSERT_EQ(header_data.get_fields().size(), 2); | |||||
| ASSERT_EQ(header_data.GetFields().size(), 2); | |||||
| std::string desc_new = "this is a test1"; | std::string desc_new = "this is a test1"; | ||||
| json schemaContent_new = R"({"name": {"type": "string"}, | json schemaContent_new = R"({"name": {"type": "string"}, | ||||
| @@ -121,7 +121,7 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||||
| mindrecord::ShardHeader header_data_new; | mindrecord::ShardHeader header_data_new; | ||||
| header_data_new.AddSchema(schema_new); | header_data_new.AddSchema(schema_new); | ||||
| ASSERT_EQ(header_data_new.get_schemas().size(), 1); | |||||
| ASSERT_EQ(header_data_new.GetSchemas().size(), 1); | |||||
| // test add fields | // test add fields | ||||
| std::vector<std::string> single_fields; | std::vector<std::string> single_fields; | ||||
| @@ -131,25 +131,25 @@ TEST_F(TestShardHeader, AddIndexFields) { | |||||
| single_fields.push_back("box"); | single_fields.push_back("box"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | res = header_data_new.AddIndexFields(single_fields); | ||||
| ASSERT_EQ(res, FAILED); | ASSERT_EQ(res, FAILED); | ||||
| ASSERT_EQ(header_data_new.get_fields().size(), 1); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||||
| single_fields.push_back("name"); | single_fields.push_back("name"); | ||||
| single_fields.push_back("box"); | single_fields.push_back("box"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | res = header_data_new.AddIndexFields(single_fields); | ||||
| ASSERT_EQ(res, FAILED); | ASSERT_EQ(res, FAILED); | ||||
| ASSERT_EQ(header_data_new.get_fields().size(), 1); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||||
| single_fields.clear(); | single_fields.clear(); | ||||
| single_fields.push_back("names"); | single_fields.push_back("names"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | res = header_data_new.AddIndexFields(single_fields); | ||||
| ASSERT_EQ(res, FAILED); | ASSERT_EQ(res, FAILED); | ||||
| ASSERT_EQ(header_data_new.get_fields().size(), 1); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 1); | |||||
| single_fields.clear(); | single_fields.clear(); | ||||
| single_fields.push_back("box"); | single_fields.push_back("box"); | ||||
| res = header_data_new.AddIndexFields(single_fields); | res = header_data_new.AddIndexFields(single_fields); | ||||
| ASSERT_EQ(res, SUCCESS); | ASSERT_EQ(res, SUCCESS); | ||||
| ASSERT_EQ(header_data_new.get_fields().size(), 2); | |||||
| ASSERT_EQ(header_data_new.GetFields().size(), 2); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -139,7 +139,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { | |||||
| const int kPar = 2; | const int kPar = 2; | ||||
| std::vector<std::shared_ptr<ShardOperator>> ops; | std::vector<std::shared_ptr<ShardOperator>> ops; | ||||
| ops.push_back(std::make_shared<ShardSample>(kNum, kDen, kPar)); | ops.push_back(std::make_shared<ShardSample>(kNum, kDen, kPar)); | ||||
| auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->get_partitions(); | |||||
| auto partitions = std::dynamic_pointer_cast<ShardSample>(ops[0])->GetPartitions(); | |||||
| ASSERT_TRUE(partitions.first == 4); | ASSERT_TRUE(partitions.first == 4); | ||||
| ASSERT_TRUE(partitions.second == 2); | ASSERT_TRUE(partitions.second == 2); | ||||
| @@ -57,15 +57,15 @@ TEST_F(TestShardPage, TestBasic) { | |||||
| Page page = | Page page = | ||||
| Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); | Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); | ||||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||||
| EXPECT_EQ(kGoldenSize, page.get_page_size()); | |||||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||||
| EXPECT_EQ(kGoldenEnd, page.get_end_row_id()); | |||||
| ASSERT_TRUE(std::make_pair(4, kOffset) == page.get_last_row_group_id()); | |||||
| ASSERT_TRUE(golden_row_group == page.get_row_group_ids()); | |||||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||||
| EXPECT_EQ(kGoldenSize, page.GetPageSize()); | |||||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||||
| EXPECT_EQ(kGoldenEnd, page.GetEndRowID()); | |||||
| ASSERT_TRUE(std::make_pair(4, kOffset) == page.GetLastRowGroupID()); | |||||
| ASSERT_TRUE(golden_row_group == page.GetRowGroupIds()); | |||||
| } | } | ||||
| TEST_F(TestShardPage, TestSetter) { | TEST_F(TestShardPage, TestSetter) { | ||||
| @@ -86,43 +86,43 @@ TEST_F(TestShardPage, TestSetter) { | |||||
| Page page = | Page page = | ||||
| Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); | Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); | ||||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||||
| EXPECT_EQ(kGoldenSize, page.get_page_size()); | |||||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||||
| EXPECT_EQ(kGoldenEnd, page.get_end_row_id()); | |||||
| ASSERT_TRUE(std::make_pair(4, kOffset1) == page.get_last_row_group_id()); | |||||
| ASSERT_TRUE(golden_row_group == page.get_row_group_ids()); | |||||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||||
| EXPECT_EQ(kGoldenSize, page.GetPageSize()); | |||||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||||
| EXPECT_EQ(kGoldenEnd, page.GetEndRowID()); | |||||
| ASSERT_TRUE(std::make_pair(4, kOffset1) == page.GetLastRowGroupID()); | |||||
| ASSERT_TRUE(golden_row_group == page.GetRowGroupIds()); | |||||
| const int kNewEnd = 33; | const int kNewEnd = 33; | ||||
| const int kNewSize = 300; | const int kNewSize = 300; | ||||
| std::vector<std::pair<int, uint64_t>> new_row_group = {{0, 100}, {100, 200}, {200, 3000}}; | std::vector<std::pair<int, uint64_t>> new_row_group = {{0, 100}, {100, 200}, {200, 3000}}; | ||||
| page.set_end_row_id(kNewEnd); | |||||
| page.set_page_size(kNewSize); | |||||
| page.set_row_group_ids(new_row_group); | |||||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||||
| EXPECT_EQ(kNewSize, page.get_page_size()); | |||||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||||
| EXPECT_EQ(kNewEnd, page.get_end_row_id()); | |||||
| ASSERT_TRUE(std::make_pair(200, kOffset2) == page.get_last_row_group_id()); | |||||
| ASSERT_TRUE(new_row_group == page.get_row_group_ids()); | |||||
| page.SetEndRowID(kNewEnd); | |||||
| page.SetPageSize(kNewSize); | |||||
| page.SetRowGroupIds(new_row_group); | |||||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||||
| EXPECT_EQ(kNewSize, page.GetPageSize()); | |||||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||||
| EXPECT_EQ(kNewEnd, page.GetEndRowID()); | |||||
| ASSERT_TRUE(std::make_pair(200, kOffset2) == page.GetLastRowGroupID()); | |||||
| ASSERT_TRUE(new_row_group == page.GetRowGroupIds()); | |||||
| page.DeleteLastGroupId(); | page.DeleteLastGroupId(); | ||||
| EXPECT_EQ(kGoldenPageId, page.get_page_id()); | |||||
| EXPECT_EQ(kGoldenShardId, page.get_shard_id()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); | |||||
| ASSERT_TRUE(kGoldenType == page.get_page_type()); | |||||
| EXPECT_EQ(3000, page.get_page_size()); | |||||
| EXPECT_EQ(kGoldenStart, page.get_start_row_id()); | |||||
| EXPECT_EQ(kNewEnd, page.get_end_row_id()); | |||||
| ASSERT_TRUE(std::make_pair(100, kOffset3) == page.get_last_row_group_id()); | |||||
| EXPECT_EQ(kGoldenPageId, page.GetPageID()); | |||||
| EXPECT_EQ(kGoldenShardId, page.GetShardID()); | |||||
| EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); | |||||
| ASSERT_TRUE(kGoldenType == page.GetPageType()); | |||||
| EXPECT_EQ(3000, page.GetPageSize()); | |||||
| EXPECT_EQ(kGoldenStart, page.GetStartRowID()); | |||||
| EXPECT_EQ(kNewEnd, page.GetEndRowID()); | |||||
| ASSERT_TRUE(std::make_pair(100, kOffset3) == page.GetLastRowGroupID()); | |||||
| new_row_group.pop_back(); | new_row_group.pop_back(); | ||||
| ASSERT_TRUE(new_row_group == page.get_row_group_ids()); | |||||
| ASSERT_TRUE(new_row_group == page.GetRowGroupIds()); | |||||
| } | } | ||||
| TEST_F(TestShardPage, TestJson) { | TEST_F(TestShardPage, TestJson) { | ||||
| @@ -107,15 +107,15 @@ TEST_F(TestShardSchema, TestFunction) { | |||||
| std::shared_ptr<Schema> schema = Schema::Build(desc, schema_content); | std::shared_ptr<Schema> schema = Schema::Build(desc, schema_content); | ||||
| ASSERT_NE(schema, nullptr); | ASSERT_NE(schema, nullptr); | ||||
| ASSERT_EQ(schema->get_desc(), desc); | |||||
| ASSERT_EQ(schema->GetDesc(), desc); | |||||
| json schema_json = schema->GetSchema(); | json schema_json = schema->GetSchema(); | ||||
| ASSERT_EQ(schema_json["desc"], desc); | ASSERT_EQ(schema_json["desc"], desc); | ||||
| ASSERT_EQ(schema_json["schema"], schema_content); | ASSERT_EQ(schema_json["schema"], schema_content); | ||||
| ASSERT_EQ(schema->get_schema_id(), -1); | |||||
| schema->set_schema_id(2); | |||||
| ASSERT_EQ(schema->get_schema_id(), 2); | |||||
| ASSERT_EQ(schema->GetSchemaID(), -1); | |||||
| schema->SetSchemaID(2); | |||||
| ASSERT_EQ(schema->GetSchemaID(), 2); | |||||
| } | } | ||||
| TEST_F(TestStatistics, StatisticPart) { | TEST_F(TestStatistics, StatisticPart) { | ||||
| @@ -137,8 +137,8 @@ TEST_F(TestStatistics, StatisticPart) { | |||||
| ASSERT_NE(statistics, nullptr); | ASSERT_NE(statistics, nullptr); | ||||
| MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc(); | |||||
| MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump(); | |||||
| MS_LOG(INFO) << "test GetDesc(), result: " << statistics->GetDesc(); | |||||
| MS_LOG(INFO) << "test GetStatistics, result: " << statistics->GetStatistics().dump(); | |||||
| statistic_json["test"] = "test"; | statistic_json["test"] = "test"; | ||||
| statistics = Statistics::Build(desc, statistic_json); | statistics = Statistics::Build(desc, statistic_json); | ||||
| @@ -194,8 +194,8 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) { | |||||
| fw.Open(file_names); | fw.Open(file_names); | ||||
| uint64_t header_size = 1 << 14; | uint64_t header_size = 1 << 14; | ||||
| uint64_t page_size = 1 << 15; | uint64_t page_size = 1 << 15; | ||||
| fw.set_header_size(header_size); | |||||
| fw.set_page_size(page_size); | |||||
| fw.SetHeaderSize(header_size); | |||||
| fw.SetPageSize(page_size); | |||||
| // set shardHeader | // set shardHeader | ||||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | ||||
| @@ -331,8 +331,8 @@ TEST_F(TestShardWriter, TestShardWriterTrial) { | |||||
| fw.Open(file_names); | fw.Open(file_names); | ||||
| uint64_t header_size = 1 << 14; | uint64_t header_size = 1 << 14; | ||||
| uint64_t page_size = 1 << 17; | uint64_t page_size = 1 << 17; | ||||
| fw.set_header_size(header_size); | |||||
| fw.set_page_size(page_size); | |||||
| fw.SetHeaderSize(header_size); | |||||
| fw.SetPageSize(page_size); | |||||
| // set shardHeader | // set shardHeader | ||||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | ||||
| @@ -466,8 +466,8 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) { | |||||
| fw.Open(file_names); | fw.Open(file_names); | ||||
| uint64_t header_size = 1 << 14; | uint64_t header_size = 1 << 14; | ||||
| uint64_t page_size = 1 << 17; | uint64_t page_size = 1 << 17; | ||||
| fw.set_header_size(header_size); | |||||
| fw.set_page_size(page_size); | |||||
| fw.SetHeaderSize(header_size); | |||||
| fw.SetPageSize(page_size); | |||||
| // set shardHeader | // set shardHeader | ||||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | ||||
| @@ -567,8 +567,8 @@ TEST_F(TestShardWriter, DataCheck) { | |||||
| fw.Open(file_names); | fw.Open(file_names); | ||||
| uint64_t header_size = 1 << 14; | uint64_t header_size = 1 << 14; | ||||
| uint64_t page_size = 1 << 17; | uint64_t page_size = 1 << 17; | ||||
| fw.set_header_size(header_size); | |||||
| fw.set_page_size(page_size); | |||||
| fw.SetHeaderSize(header_size); | |||||
| fw.SetPageSize(page_size); | |||||
| // set shardHeader | // set shardHeader | ||||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | ||||
| @@ -668,8 +668,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) { | |||||
| fw.Open(file_names); | fw.Open(file_names); | ||||
| uint64_t header_size = 1 << 14; | uint64_t header_size = 1 << 14; | ||||
| uint64_t page_size = 1 << 17; | uint64_t page_size = 1 << 17; | ||||
| fw.set_header_size(header_size); | |||||
| fw.set_page_size(page_size); | |||||
| fw.SetHeaderSize(header_size); | |||||
| fw.SetPageSize(page_size); | |||||
| // set shardHeader | // set shardHeader | ||||
| fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | fw.SetShardHeader(std::make_shared<mindrecord::ShardHeader>(header_data)); | ||||