From cf352d190f686bea16421e0a0103cef34acf49ba Mon Sep 17 00:00:00 2001 From: jonyguo Date: Fri, 8 May 2020 14:40:01 +0800 Subject: [PATCH] format func name for mindrecord --- .../engine/datasetops/source/mindrecord_op.cc | 8 +- .../ccsrc/mindrecord/common/shard_pybind.cc | 28 +++---- .../ccsrc/mindrecord/include/shard_category.h | 4 +- .../ccsrc/mindrecord/include/shard_header.h | 30 +++---- .../ccsrc/mindrecord/include/shard_index.h | 2 +- .../ccsrc/mindrecord/include/shard_operator.h | 12 +-- .../ccsrc/mindrecord/include/shard_page.h | 24 +++--- .../mindrecord/include/shard_pk_sample.h | 2 +- .../ccsrc/mindrecord/include/shard_reader.h | 12 +-- .../ccsrc/mindrecord/include/shard_sample.h | 6 +- .../ccsrc/mindrecord/include/shard_schema.h | 8 +- .../ccsrc/mindrecord/include/shard_segment.h | 2 +- .../ccsrc/mindrecord/include/shard_shuffle.h | 2 +- .../mindrecord/include/shard_statistics.h | 8 +- .../ccsrc/mindrecord/include/shard_task.h | 4 +- .../ccsrc/mindrecord/include/shard_writer.h | 4 +- .../mindrecord/io/shard_index_generator.cc | 44 +++++----- mindspore/ccsrc/mindrecord/io/shard_reader.cc | 80 +++++++++---------- .../ccsrc/mindrecord/io/shard_segment.cc | 16 ++-- mindspore/ccsrc/mindrecord/io/shard_writer.cc | 74 ++++++++--------- .../ccsrc/mindrecord/meta/shard_category.cc | 2 +- .../ccsrc/mindrecord/meta/shard_header.cc | 50 ++++++------ .../ccsrc/mindrecord/meta/shard_index.cc | 2 +- .../ccsrc/mindrecord/meta/shard_pk_sample.cc | 2 +- .../ccsrc/mindrecord/meta/shard_sample.cc | 12 +-- .../ccsrc/mindrecord/meta/shard_schema.cc | 10 +-- .../ccsrc/mindrecord/meta/shard_shuffle.cc | 2 +- .../ccsrc/mindrecord/meta/shard_statistics.cc | 12 +-- mindspore/ccsrc/mindrecord/meta/shard_task.cc | 8 +- tests/ut/cpp/mindrecord/ut_shard.cc | 12 +-- .../ut/cpp/mindrecord/ut_shard_header_test.cc | 22 ++--- .../cpp/mindrecord/ut_shard_operator_test.cc | 2 +- tests/ut/cpp/mindrecord/ut_shard_page_test.cc | 78 +++++++++--------- .../ut/cpp/mindrecord/ut_shard_schema_test.cc | 12 +-- .../ut/cpp/mindrecord/ut_shard_writer_test.cc | 20 ++--- 35 files changed, 308 insertions(+), 308 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index 9458ca6307..1f34a1d373 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -108,7 +108,7 @@ Status MindRecordOp::Init() { data_schema_ = std::make_unique(); - std::vector> schema_vec = shard_reader_->get_shard_header()->get_schemas(); + std::vector> schema_vec = shard_reader_->GetShardHeader()->GetSchemas(); // check whether schema exists, if so use the first one CHECK_FAIL_RETURN_UNEXPECTED(!schema_vec.empty(), "No schema found"); mindrecord::json mr_schema = schema_vec[0]->GetSchema()["schema"]; @@ -155,7 +155,7 @@ Status MindRecordOp::Init() { column_name_mapping_[columns_to_load_[i]] = i; } - num_rows_ = shard_reader_->get_num_rows(); + num_rows_ = shard_reader_->GetNumRows(); // Compute how many buffers we would need to accomplish rowsPerBuffer buffers_needed_ = (num_rows_ + rows_per_buffer_ - 1) / rows_per_buffer_; RETURN_IF_NOT_OK(SetColumnsBlob()); @@ -164,7 +164,7 @@ Status MindRecordOp::Init() { } Status MindRecordOp::SetColumnsBlob() { - columns_blob_ = shard_reader_->get_blob_fields().second; + columns_blob_ = shard_reader_->GetBlobFields().second; // get the exactly blob fields by columns_to_load_ std::vector columns_blob_exact; @@ -600,7 +600,7 @@ Status MindRecordOp::FetchBlockBuffer(const int32_t &buffer_id) { // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work Status MindRecordOp::operator()() { RETURN_IF_NOT_OK(LaunchThreadAndInitOp()); - num_rows_ = shard_reader_->get_num_rows(); + num_rows_ = shard_reader_->GetNumRows(); buffers_needed_ = num_rows_ / rows_per_buffer_; if (num_rows_ % rows_per_buffer_ != 0) { diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 8718e9b871..0eb9ac14b2 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -39,18 +39,18 @@ namespace mindrecord { void BindSchema(py::module *m) { (void)py::class_>(*m, "Schema", py::module_local()) .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Schema::Build) - .def("get_desc", &Schema::get_desc) + .def("get_desc", &Schema::GetDesc) .def("get_schema_content", (py::object(Schema::*)()) & Schema::GetSchemaForPython) - .def("get_blob_fields", &Schema::get_blob_fields) - .def("get_schema_id", &Schema::get_schema_id); + .def("get_blob_fields", &Schema::GetBlobFields) + .def("get_schema_id", &Schema::GetSchemaID); } void BindStatistics(const py::module *m) { (void)py::class_>(*m, "Statistics", py::module_local()) .def_static("build", (std::shared_ptr(*)(std::string, py::handle)) & Statistics::Build) - .def("get_desc", &Statistics::get_desc) + .def("get_desc", &Statistics::GetDesc) .def("get_statistics", (py::object(Statistics::*)()) & Statistics::GetStatisticsForPython) - .def("get_statistics_id", &Statistics::get_statistics_id); + .def("get_statistics_id", &Statistics::GetStatisticsID); } void BindShardHeader(const py::module *m) { @@ -60,9 +60,9 @@ void BindShardHeader(const py::module *m) { .def("add_statistics", &ShardHeader::AddStatistic) .def("add_index_fields", (MSRStatus(ShardHeader::*)(const std::vector &)) & ShardHeader::AddIndexFields) - .def("get_meta", &ShardHeader::get_schemas) - .def("get_statistics", &ShardHeader::get_statistics) - .def("get_fields", &ShardHeader::get_fields) + .def("get_meta", &ShardHeader::GetSchemas) + .def("get_statistics", &ShardHeader::GetStatistics) + .def("get_fields", &ShardHeader::GetFields) .def("get_schema_by_id", &ShardHeader::GetSchemaByID) .def("get_statistic_by_id", &ShardHeader::GetStatisticByID); } @@ -72,8 +72,8 @@ void BindShardWriter(py::module *m) { .def(py::init<>()) .def("open", &ShardWriter::Open) .def("open_for_append", &ShardWriter::OpenForAppend) - .def("set_header_size", &ShardWriter::set_header_size) - .def("set_page_size", &ShardWriter::set_page_size) + .def("set_header_size", &ShardWriter::SetHeaderSize) + .def("set_page_size", &ShardWriter::SetPageSize) .def("set_shard_header", &ShardWriter::SetShardHeader) .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, vector> &, bool, bool)) & @@ -88,8 +88,8 @@ void BindShardReader(const py::module *m) { const std::vector> &)) & ShardReader::OpenPy) .def("launch", &ShardReader::Launch) - .def("get_header", &ShardReader::get_shard_header) - .def("get_blob_fields", &ShardReader::get_blob_fields) + .def("get_header", &ShardReader::GetShardHeader) + .def("get_blob_fields", &ShardReader::GetBlobFields) .def("get_next", (std::vector, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy) .def("finish", &ShardReader::Finish) @@ -119,9 +119,9 @@ void BindShardSegment(py::module *m) { .def("read_at_page_by_name", (std::pair, pybind11::object>>>( ShardSegment::*)(std::string, int64_t, int64_t)) & ShardSegment::ReadAtPageByNamePy) - .def("get_header", &ShardSegment::get_shard_header) + .def("get_header", &ShardSegment::GetShardHeader) .def("get_blob_fields", - (std::pair>(ShardSegment::*)()) & ShardSegment::get_blob_fields); + (std::pair>(ShardSegment::*)()) & ShardSegment::GetBlobFields); } void BindGlobalParams(py::module *m) { diff --git a/mindspore/ccsrc/mindrecord/include/shard_category.h b/mindspore/ccsrc/mindrecord/include/shard_category.h index b2fe18fbac..618a91b1d8 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/mindrecord/include/shard_category.h @@ -36,7 +36,7 @@ class ShardCategory : public ShardOperator { ~ShardCategory() override{}; - const std::vector> &get_categories() const { return categories_; } + const std::vector> &GetCategories() const { return categories_; } const std::string GetCategoryField() const { return category_field_; } @@ -46,7 +46,7 @@ class ShardCategory : public ShardOperator { bool GetReplacement() const { return replacement_; } - MSRStatus execute(ShardTask &tasks) override; + MSRStatus Execute(ShardTask &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index 70cfcdb6b7..d2c2ef0a2d 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -58,19 +58,19 @@ class ShardHeader { /// \brief get the schema /// \return the schema - std::vector> get_schemas(); + std::vector> GetSchemas(); /// \brief get Statistics /// \return the Statistic - std::vector> get_statistics(); + std::vector> GetStatistics(); /// \brief get the fields of the index /// \return the fields of the index - std::vector> get_fields(); + std::vector> GetFields(); /// \brief get the index /// \return the index - std::shared_ptr get_index(); + std::shared_ptr GetIndex(); /// \brief get the schema by schemaid /// \param[in] schemaId the id of schema needs to be got @@ -80,7 +80,7 @@ class ShardHeader { /// \brief get the filepath to shard by shardID /// \param[in] shardID the id of shard which filepath needs to be obtained /// \return the filepath obtained by shardID - std::string get_shard_address_by_id(int64_t shard_id); + std::string GetShardAddressByID(int64_t shard_id); /// \brief get the statistic by statistic id /// \param[in] statisticId the id of statistic needs to be get @@ -89,7 +89,7 @@ class ShardHeader { MSRStatus InitByFiles(const std::vector &file_paths); - void set_index(Index index) { index_ = std::make_shared(index); } + void SetIndex(Index index) { index_ = std::make_shared(index); } std::pair, MSRStatus> GetPage(const int &shard_id, const int &page_id); @@ -103,21 +103,21 @@ class ShardHeader { const std::pair> GetPageByGroupId(const int &group_id, const int &shard_id); - std::vector get_shard_addresses() const { return shard_addresses_; } + std::vector GetShardAddresses() const { return shard_addresses_; } - int get_shard_count() const { return shard_count_; } + int GetShardCount() const { return shard_count_; } - int get_schema_count() const { return schema_.size(); } + int GetSchemaCount() const { return schema_.size(); } - uint64_t get_header_size() const { return header_size_; } + uint64_t GetHeaderSize() const { return header_size_; } - uint64_t get_page_size() const { return page_size_; } + uint64_t GetPageSize() const { return page_size_; } - void set_header_size(const uint64_t &header_size) { header_size_ = header_size; } + void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } - void set_page_size(const uint64_t &page_size) { page_size_ = page_size; } + void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } - const string get_version() { return version_; } + const string GetVersion() { return version_; } std::vector SerializeHeader(); @@ -132,7 +132,7 @@ class ShardHeader { /// \param[in] the shard data real path /// \param[in] the headers which readed from the shard data /// \return SUCCESS/FAILED - MSRStatus get_headers(const vector &real_addresses, std::vector &headers); + MSRStatus GetHeaders(const vector &real_addresses, std::vector &headers); MSRStatus ValidateField(const std::vector &field_name, json schema, const uint64_t &schema_id); diff --git a/mindspore/ccsrc/mindrecord/include/shard_index.h b/mindspore/ccsrc/mindrecord/include/shard_index.h index 6d4bc36457..d430c5bdcf 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_index.h +++ b/mindspore/ccsrc/mindrecord/include/shard_index.h @@ -52,7 +52,7 @@ class Index { /// \brief get stored fields /// \return fields stored - std::vector > get_fields(); + std::vector > GetFields(); private: std::vector > fields_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h index 7476660a70..59c77074a1 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/mindrecord/include/shard_operator.h @@ -26,23 +26,23 @@ class ShardOperator { virtual ~ShardOperator() = default; MSRStatus operator()(ShardTask &tasks) { - if (SUCCESS != this->pre_execute(tasks)) { + if (SUCCESS != this->PreExecute(tasks)) { return FAILED; } - if (SUCCESS != this->execute(tasks)) { + if (SUCCESS != this->Execute(tasks)) { return FAILED; } - if (SUCCESS != this->suf_execute(tasks)) { + if (SUCCESS != this->SufExecute(tasks)) { return FAILED; } return SUCCESS; } - virtual MSRStatus pre_execute(ShardTask &tasks) { return SUCCESS; } + virtual MSRStatus PreExecute(ShardTask &tasks) { return SUCCESS; } - virtual MSRStatus execute(ShardTask &tasks) = 0; + virtual MSRStatus Execute(ShardTask &tasks) = 0; - virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } + virtual MSRStatus SufExecute(ShardTask &tasks) { return SUCCESS; } virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } }; diff --git a/mindspore/ccsrc/mindrecord/include/shard_page.h b/mindspore/ccsrc/mindrecord/include/shard_page.h index 8b7a5244bd..c22acd8d2c 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_page.h +++ b/mindspore/ccsrc/mindrecord/include/shard_page.h @@ -53,29 +53,29 @@ class Page { /// \return the json format of the page and its description json GetPage() const; - int get_page_id() const { return page_id_; } + int GetPageID() const { return page_id_; } - int get_shard_id() const { return shard_id_; } + int GetShardID() const { return shard_id_; } - int get_page_type_id() const { return page_type_id_; } + int GetPageTypeID() const { return page_type_id_; } - std::string get_page_type() const { return page_type_; } + std::string GetPageType() const { return page_type_; } - uint64_t get_page_size() const { return page_size_; } + uint64_t GetPageSize() const { return page_size_; } - uint64_t get_start_row_id() const { return start_row_id_; } + uint64_t GetStartRowID() const { return start_row_id_; } - uint64_t get_end_row_id() const { return end_row_id_; } + uint64_t GetEndRowID() const { return end_row_id_; } - void set_end_row_id(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } + void SetEndRowID(const uint64_t &end_row_id) { end_row_id_ = end_row_id; } - void set_page_size(const uint64_t &page_size) { page_size_ = page_size; } + void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } - std::pair get_last_row_group_id() const { return row_group_ids_.back(); } + std::pair GetLastRowGroupID() const { return row_group_ids_.back(); } - std::vector> get_row_group_ids() const { return row_group_ids_; } + std::vector> GetRowGroupIds() const { return row_group_ids_; } - void set_row_group_ids(const std::vector> &last_row_group_ids) { + void SetRowGroupIds(const std::vector> &last_row_group_ids) { row_group_ids_ = last_row_group_ids; } diff --git a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h index df3888dad4..4f1a1c307a 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h +++ b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h @@ -37,7 +37,7 @@ class ShardPkSample : public ShardCategory { ~ShardPkSample() override{}; - MSRStatus suf_execute(ShardTask &tasks) override; + MSRStatus SufExecute(ShardTask &tasks) override; private: bool shuffle_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 6b90275cfc..840f5a1b48 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -107,11 +107,11 @@ class ShardReader { /// \brief aim to get the meta data /// \return the metadata - std::shared_ptr get_shard_header() const; + std::shared_ptr GetShardHeader() const; /// \brief get the number of shards /// \return # of shards - int get_shard_count() const; + int GetShardCount() const; /// \brief get the number of rows in database /// \param[in] file_path the path of ONE file, any file in dataset is fine @@ -126,7 +126,7 @@ class ShardReader { /// \brief get the number of rows in database /// \return # of rows - int get_num_rows() const; + int GetNumRows() const; /// \brief Read the summary of row groups /// \return the tuple of 4 elements @@ -185,7 +185,7 @@ class ShardReader { /// \brief get blob filed list /// \return blob field list - std::pair> get_blob_fields(); + std::pair> GetBlobFields(); /// \brief reset reader /// \return null @@ -193,10 +193,10 @@ class ShardReader { /// \brief set flag of all-in-index /// \return null - void set_all_in_index(bool all_in_index) { all_in_index_ = all_in_index; } + void SetAllInIndex(bool all_in_index) { all_in_index_ = all_in_index; } /// \brief get NLP flag - bool get_nlp_flag(); + bool GetNlpFlag(); /// \brief get all classes MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h index b16fc5cc4f..7905f328f9 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h @@ -38,11 +38,11 @@ class ShardSample : public ShardOperator { ~ShardSample() override{}; - const std::pair get_partitions() const; + const std::pair GetPartitions() const; - MSRStatus execute(ShardTask &tasks) override; + MSRStatus Execute(ShardTask &tasks) override; - MSRStatus suf_execute(ShardTask &tasks) override; + MSRStatus SufExecute(ShardTask &tasks) override; int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; diff --git a/mindspore/ccsrc/mindrecord/include/shard_schema.h b/mindspore/ccsrc/mindrecord/include/shard_schema.h index ee0222ec8e..4ef134bde2 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_schema.h +++ b/mindspore/ccsrc/mindrecord/include/shard_schema.h @@ -51,7 +51,7 @@ class Schema { /// \brief get the schema and its description /// \return the json format of the schema and its description - std::string get_desc() const; + std::string GetDesc() const; /// \brief get the schema and its description /// \return the json format of the schema and its description @@ -63,15 +63,15 @@ class Schema { /// set the schema id /// \param[in] id the id need to be set - void set_schema_id(int64_t id); + void SetSchemaID(int64_t id); /// get the schema id /// \return the int64 schema id - int64_t get_schema_id() const; + int64_t GetSchemaID() const; /// get the blob fields /// \return the vector blob fields - std::vector get_blob_fields() const; + std::vector GetBlobFields() const; private: Schema() = default; diff --git a/mindspore/ccsrc/mindrecord/include/shard_segment.h b/mindspore/ccsrc/mindrecord/include/shard_segment.h index c13ea5dccf..9ffb7aee88 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_segment.h +++ b/mindspore/ccsrc/mindrecord/include/shard_segment.h @@ -81,7 +81,7 @@ class ShardSegment : public ShardReader { std::pair, pybind11::object>>> ReadAtPageByNamePy( std::string category_name, int64_t page_no, int64_t n_rows_of_page); - std::pair> get_blob_fields(); + std::pair> GetBlobFields(); private: std::pair>> WrapCategoryInfo(); diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h index 027a5ad527..a9c54e6239 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h @@ -28,7 +28,7 @@ class ShardShuffle : public ShardOperator { ~ShardShuffle() override{}; - MSRStatus execute(ShardTask &tasks) override; + MSRStatus Execute(ShardTask &tasks) override; private: uint32_t shuffle_seed_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_statistics.h b/mindspore/ccsrc/mindrecord/include/shard_statistics.h index 44956332e1..7fc2f968cd 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_statistics.h +++ b/mindspore/ccsrc/mindrecord/include/shard_statistics.h @@ -53,11 +53,11 @@ class Statistics { /// \brief get the description /// \return the description - std::string get_desc() const; + std::string GetDesc() const; /// \brief get the statistic /// \return json format of the statistic - json get_statistics() const; + json GetStatistics() const; /// \brief get the statistic for python /// \return the python object of statistics @@ -66,11 +66,11 @@ class Statistics { /// \brief decode the bson statistics to json /// \param[in] encodedStatistics the bson type of statistics /// \return json type of statistic - void set_statistics_id(int64_t id); + void SetStatisticsID(int64_t id); /// \brief get the statistics id /// \return the int64 statistics id - int64_t get_statistics_id() const; + int64_t GetStatisticsID() const; private: /// \brief validate the statistic diff --git a/mindspore/ccsrc/mindrecord/include/shard_task.h b/mindspore/ccsrc/mindrecord/include/shard_task.h index b276b5150f..d48c25c9cd 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/mindrecord/include/shard_task.h @@ -39,9 +39,9 @@ class ShardTask { uint32_t SizeOfRows() const; - std::tuple, std::vector, json> &get_task_by_id(size_t id); + std::tuple, std::vector, json> &GetTaskByID(size_t id); - std::tuple, std::vector, json> &get_random_task(); + std::tuple, std::vector, json> &GetRandomTask(); static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 78a434fc97..4679814287 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -69,12 +69,12 @@ class ShardWriter { /// \brief Set file size /// \param[in] header_size the size of header, only (1< ShardIndexGenerator::GetValueByField(const str } // schema does not contain the field - auto schema = shard_header_.get_schemas()[0]->GetSchema()["schema"]; + auto schema = shard_header_.GetSchemas()[0]->GetSchema()["schema"]; if (schema.find(field) == schema.end()) { MS_LOG(ERROR) << "The field " << field << " is not found in schema " << schema; return {FAILED, ""}; @@ -203,7 +203,7 @@ MSRStatus ShardIndexGenerator::CreateShardNameTable(sqlite3 *db, const std::stri } std::pair ShardIndexGenerator::CreateDatabase(int shard_no) { - std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); + std::string shard_address = shard_header_.GetShardAddressByID(shard_no); if (shard_address.empty()) { MS_LOG(ERROR) << "Shard address is null, shard no: " << shard_no; return {FAILED, nullptr}; @@ -357,12 +357,12 @@ MSRStatus ShardIndexGenerator::BindParameterExecuteSQL( MSRStatus ShardIndexGenerator::AddBlobPageInfo(std::vector> &row_data, const std::shared_ptr cur_blob_page, uint64_t &cur_blob_page_offset, std::fstream &in) { - row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->get_page_id())); + row_data.emplace_back(":PAGE_ID_BLOB", "INTEGER", std::to_string(cur_blob_page->GetPageID())); // blob data start row_data.emplace_back(":PAGE_OFFSET_BLOB", "INTEGER", std::to_string(cur_blob_page_offset)); auto &io_seekg_blob = - in.seekg(page_size_ * cur_blob_page->get_page_id() + header_size_ + cur_blob_page_offset, std::ios::beg); + in.seekg(page_size_ * cur_blob_page->GetPageID() + header_size_ + cur_blob_page_offset, std::ios::beg); if (!io_seekg_blob.good() || io_seekg_blob.fail() || io_seekg_blob.bad()) { MS_LOG(ERROR) << "File seekg failed"; in.close(); @@ -405,7 +405,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map cur_raw_page = shard_header_.GetPage(shard_no, raw_page_id).first; // related blob page - vector> row_group_list = cur_raw_page->get_row_group_ids(); + vector> row_group_list = cur_raw_page->GetRowGroupIds(); // pair: row_group id, offset in raw data page for (pair blob_ids : row_group_list) { @@ -415,18 +415,18 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map(blob_ids.second); uint64_t cur_blob_page_offset = 0; - for (unsigned int i = cur_blob_page->get_start_row_id(); i < cur_blob_page->get_end_row_id(); ++i) { + for (unsigned int i = cur_blob_page->GetStartRowID(); i < cur_blob_page->GetEndRowID(); ++i) { std::vector> row_data; row_data.emplace_back(":ROW_ID", "INTEGER", std::to_string(i)); - row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->get_page_type_id())); - row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->get_page_id())); + row_data.emplace_back(":ROW_GROUP_ID", "INTEGER", std::to_string(cur_blob_page->GetPageTypeID())); + row_data.emplace_back(":PAGE_ID_RAW", "INTEGER", std::to_string(cur_raw_page->GetPageID())); // raw data start row_data.emplace_back(":PAGE_OFFSET_RAW", "INTEGER", std::to_string(cur_raw_page_offset)); // calculate raw data end auto &io_seekg = - in.seekg(page_size_ * (cur_raw_page->get_page_id()) + header_size_ + cur_raw_page_offset, std::ios::beg); + in.seekg(page_size_ * (cur_raw_page->GetPageID()) + header_size_ + cur_raw_page_offset, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { MS_LOG(ERROR) << "File seekg failed"; in.close(); @@ -473,7 +473,7 @@ ROW_DATA ShardIndexGenerator::GenerateRowData(int shard_no, const std::map &schema_detail) { std::vector> fields; // index fields - std::vector> index_fields = shard_header_.get_fields(); + std::vector> index_fields = shard_header_.GetFields(); for (const auto &field : index_fields) { if (field.first >= schema_detail.size()) { return {FAILED, {}}; @@ -504,7 +504,7 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std const std::vector &raw_page_ids, const std::map &blob_id_to_page_id) { // Add index data to database - std::string shard_address = shard_header_.get_shard_address_by_id(shard_no); + std::string shard_address = shard_header_.GetShardAddressByID(shard_no); if (shard_address.empty()) { MS_LOG(ERROR) << "Shard address is null"; return FAILED; @@ -546,12 +546,12 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std } MSRStatus ShardIndexGenerator::WriteToDatabase() { - fields_ = shard_header_.get_fields(); - page_size_ = shard_header_.get_page_size(); - header_size_ = shard_header_.get_header_size(); - schema_count_ = shard_header_.get_schema_count(); - if (shard_header_.get_shard_count() > kMaxShardCount) { - MS_LOG(ERROR) << "num shards: " << shard_header_.get_shard_count() << " exceeds max count:" << kMaxSchemaCount; + fields_ = shard_header_.GetFields(); + page_size_ = shard_header_.GetPageSize(); + header_size_ = shard_header_.GetHeaderSize(); + schema_count_ = shard_header_.GetSchemaCount(); + if (shard_header_.GetShardCount() > kMaxShardCount) { + MS_LOG(ERROR) << "num shards: " << shard_header_.GetShardCount() << " exceeds max count:" << kMaxSchemaCount; return FAILED; } task_ = 0; // set two atomic vars to initial value @@ -559,7 +559,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { // spawn half the physical threads or total number of shards whichever is smaller const unsigned int num_workers = - std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast(shard_header_.get_shard_count())); + std::min(std::thread::hardware_concurrency() / 2 + 1, static_cast(shard_header_.GetShardCount())); std::vector threads; threads.reserve(num_workers); @@ -576,7 +576,7 @@ MSRStatus ShardIndexGenerator::WriteToDatabase() { void ShardIndexGenerator::DatabaseWriter() { int shard_no = task_++; - while (shard_no < shard_header_.get_shard_count()) { + while (shard_no < shard_header_.GetShardCount()) { auto db = CreateDatabase(shard_no); if (db.first != SUCCESS || db.second == nullptr || write_success_ == false) { write_success_ = false; @@ -592,10 +592,10 @@ void ShardIndexGenerator::DatabaseWriter() { std::vector raw_page_ids; for (uint64_t i = 0; i < total_pages; ++i) { std::shared_ptr cur_page = shard_header_.GetPage(shard_no, i).first; - if (cur_page->get_page_type() == "RAW_DATA") { + if (cur_page->GetPageType() == "RAW_DATA") { raw_page_ids.push_back(i); - } else if (cur_page->get_page_type() == "BLOB_DATA") { - blob_id_to_page_id[cur_page->get_page_type_id()] = i; + } else if (cur_page->GetPageType() == "BLOB_DATA") { + blob_id_to_page_id[cur_page->GetPageTypeID()] = i; } } diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 69e14510e8..bd0394ac42 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -56,9 +56,9 @@ MSRStatus ShardReader::Init(const std::string &file_path) { return FAILED; } shard_header_ = std::make_shared(sh); - header_size_ = shard_header_->get_header_size(); - page_size_ = shard_header_->get_page_size(); - file_paths_ = shard_header_->get_shard_addresses(); + header_size_ = shard_header_->GetHeaderSize(); + page_size_ = shard_header_->GetPageSize(); + file_paths_ = shard_header_->GetShardAddresses(); for (const auto &file : file_paths_) { sqlite3 *db = nullptr; @@ -105,7 +105,7 @@ MSRStatus ShardReader::Init(const std::string &file_path) { MSRStatus ShardReader::CheckColumnList(const std::vector &selected_columns) { vector inSchema(selected_columns.size(), 0); - for (auto &p : get_shard_header()->get_schemas()) { + for (auto &p : GetShardHeader()->GetSchemas()) { auto schema = p->GetSchema()["schema"]; for (unsigned int i = 0; i < selected_columns.size(); ++i) { if (schema.find(selected_columns[i]) != schema.end()) { @@ -183,15 +183,15 @@ void ShardReader::Close() { FileStreamsOperator(); } -std::shared_ptr ShardReader::get_shard_header() const { return shard_header_; } +std::shared_ptr ShardReader::GetShardHeader() const { return shard_header_; } -int ShardReader::get_shard_count() const { return shard_header_->get_shard_count(); } +int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } -int ShardReader::get_num_rows() const { return num_rows_; } +int ShardReader::GetNumRows() const { return num_rows_; } std::vector> ShardReader::ReadRowGroupSummary() { std::vector> row_group_summary; - int shard_count = shard_header_->get_shard_count(); + int shard_count = shard_header_->GetShardCount(); if (shard_count <= 0) { return row_group_summary; } @@ -205,13 +205,13 @@ std::vector> ShardReader::ReadRowGroupSummar for (uint64_t page_id = 0; page_id <= last_page_id; ++page_id) { const auto &page_t = shard_header_->GetPage(shard_id, page_id); const auto &page = page_t.first; - if (page->get_page_type() != kPageTypeBlob) continue; - uint64_t start_row_id = page->get_start_row_id(); - if (start_row_id > page->get_end_row_id()) { + if (page->GetPageType() != kPageTypeBlob) continue; + uint64_t start_row_id = page->GetStartRowID(); + if (start_row_id > page->GetEndRowID()) { return std::vector>(); } - uint64_t number_of_rows = page->get_end_row_id() - start_row_id; - row_group_summary.emplace_back(shard_id, page->get_page_type_id(), start_row_id, number_of_rows); + uint64_t number_of_rows = page->GetEndRowID() - start_row_id; + row_group_summary.emplace_back(shard_id, page->GetPageTypeID(), start_row_id, number_of_rows); } } } @@ -265,7 +265,7 @@ MSRStatus ShardReader::ConvertLabelToJson(const std::vectorget_schemas()[0]->GetSchema()["schema"]; + auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; // convert the string to base type by schema if (schema[columns[j]]["type"] == "int32") { @@ -317,7 +317,7 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { std::map index_columns; - for (auto &field : get_shard_header()->get_fields()) { + for (auto &field : GetShardHeader()->GetFields()) { index_columns[field.second] = field.first; } if (index_columns.find(category_field) == index_columns.end()) { @@ -400,11 +400,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupBrief(int group_id, int shard_id, const } const std::shared_ptr &page = ret.second; std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->get_page_size(); - uint64_t page_offset = page_size_ * page->get_page_id() + header_size_; - std::vector> image_offset = GetImageOffset(page->get_page_id(), shard_id); + uint64_t page_length = page->GetPageSize(); + uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id); - auto status_labels = GetLabels(page->get_page_id(), shard_id, columns); + auto status_labels = GetLabels(page->GetPageID(), shard_id, columns); if (status_labels.first != SUCCESS) { return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); } @@ -426,11 +426,11 @@ ROW_GROUP_BRIEF ShardReader::ReadRowGroupCriteria(int group_id, int shard_id, } const std::shared_ptr &page = ret.second; std::string file_name = file_paths_[shard_id]; - uint64_t page_length = page->get_page_size(); - uint64_t page_offset = page_size_ * page->get_page_id() + header_size_; - std::vector> image_offset = GetImageOffset(page->get_page_id(), shard_id, criteria); + uint64_t page_length = page->GetPageSize(); + uint64_t page_offset = page_size_ * page->GetPageID() + header_size_; + std::vector> image_offset = GetImageOffset(page->GetPageID(), shard_id, criteria); - auto status_labels = GetLabels(page->get_page_id(), shard_id, columns, criteria); + auto status_labels = GetLabels(page->GetPageID(), shard_id, columns, criteria); if (status_labels.first != SUCCESS) { return std::make_tuple(FAILED, "", 0, 0, std::vector>(), std::vector()); } @@ -458,7 +458,7 @@ std::vector> ShardReader::GetImageOffset(int page_id, int // whether use index search if (!criteria.first.empty()) { - auto schema = shard_header_->get_schemas()[0]->GetSchema(); + auto schema = shard_header_->GetSchemas()[0]->GetSchema(); // not number field should add '' in sql if (kNumberFieldTypeSet.find(schema["schema"][criteria.first]["type"]) != kNumberFieldTypeSet.end()) { @@ -497,13 +497,13 @@ void ShardReader::CheckNlp() { return; } -bool ShardReader::get_nlp_flag() { return nlp_; } +bool ShardReader::GetNlpFlag() { return nlp_; } -std::pair> ShardReader::get_blob_fields() { +std::pair> ShardReader::GetBlobFields() { std::vector blob_fields; - for (auto &p : get_shard_header()->get_schemas()) { + for (auto &p : GetShardHeader()->GetSchemas()) { // assume one schema - const auto &fields = p->get_blob_fields(); + const auto &fields = p->GetBlobFields(); blob_fields.assign(fields.begin(), fields.end()); break; } @@ -516,7 +516,7 @@ void ShardReader::CheckIfColumnInIndex(const std::vector &columns) all_in_index_ = false; return; } - for (auto &field : get_shard_header()->get_fields()) { + for (auto &field : GetShardHeader()->GetFields()) { column_schema_id_[field.second] = field.first; } for (auto &col : columns) { @@ -671,7 +671,7 @@ std::pair> ShardReader::GetLabels(int page_id, int json construct_json; for (unsigned int j = 0; j < columns.size(); ++j) { // construct json "f1": value - auto schema = shard_header_->get_schemas()[0]->GetSchema()["schema"]; + auto schema = shard_header_->GetSchemas()[0]->GetSchema()["schema"]; // convert the string to base type by schema if (schema[columns[j]]["type"] == "int32") { @@ -719,9 +719,9 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri return -1; } auto header = std::make_shared(sh); - auto file_paths = header->get_shard_addresses(); + auto file_paths = header->GetShardAddresses(); auto shard_count = file_paths.size(); - auto index_fields = header->get_fields(); + auto index_fields = header->GetFields(); std::map map_schema_id_fields; for (auto &field : index_fields) { @@ -799,7 +799,7 @@ MSRStatus ShardReader::Open(const std::string &file_path, int n_consumer, if (nlp_) { selected_columns_ = selected_columns; } else { - vector blob_fields = get_blob_fields().second; + vector blob_fields = GetBlobFields().second; for (unsigned int i = 0; i < selected_columns.size(); ++i) { if (!std::any_of(blob_fields.begin(), blob_fields.end(), [&selected_columns, i](std::string item) { return selected_columns[i] == item; })) { @@ -846,7 +846,7 @@ MSRStatus ShardReader::OpenPy(const std::string &file_path, const int &n_consume } // should remove blob field from selected_columns when call from python std::vector columns(selected_columns); - auto blob_fields = get_blob_fields().second; + auto blob_fields = GetBlobFields().second; for (auto &blob_field : blob_fields) { auto it = std::find(selected_columns.begin(), selected_columns.end(), blob_field); if (it != selected_columns.end()) { @@ -909,7 +909,7 @@ vector ShardReader::GetAllColumns() { vector columns; if (nlp_) { for (auto &c : selected_columns_) { - for (auto &p : get_shard_header()->get_schemas()) { + for (auto &p : GetShardHeader()->GetSchemas()) { auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. for (auto it = schema.begin(); it != schema.end(); ++it) { if (it.key() == c) { @@ -943,7 +943,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector(op); - auto categories = category_op->get_categories(); + auto categories = category_op->GetCategories(); int64_t num_elements = category_op->GetNumElements(); if (num_elements <= 0) { MS_LOG(ERROR) << "Parameter num_element is not positive"; @@ -1104,7 +1104,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ } // Pick up task from task list - auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]); + auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); auto shard_id = std::get<0>(std::get<0>(task)); auto group_id = std::get<1>(std::get<0>(task)); @@ -1117,7 +1117,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ // Pack image list std::vector images(addr[1] - addr[0]); - auto file_offset = header_size_ + page_size_ * (page->get_page_id()) + addr[0]; + auto file_offset = header_size_ + page_size_ * (page->GetPageID()) + addr[0]; auto &io_seekg = file_streams_random_[consumer_id][shard_id]->seekg(file_offset, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { @@ -1139,7 +1139,7 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_ if (selected_columns_.size() == 0) { images_with_exact_columns = images; } else { - auto blob_fields = get_blob_fields(); + auto blob_fields = GetBlobFields(); std::vector ordered_selected_columns_index; uint32_t index = 0; @@ -1272,7 +1272,7 @@ MSRStatus ShardReader::ConsumerByBlock(int consumer_id) { } // Pick up task from task list - auto task = tasks_.get_task_by_id(tasks_.permutation_[task_id]); + auto task = tasks_.GetTaskByID(tasks_.permutation_[task_id]); auto shard_id = std::get<0>(std::get<0>(task)); auto group_id = std::get<1>(std::get<0>(task)); diff --git a/mindspore/ccsrc/mindrecord/io/shard_segment.cc b/mindspore/ccsrc/mindrecord/io/shard_segment.cc index e015831d6b..d6536996ba 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_segment.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_segment.cc @@ -28,7 +28,7 @@ using mindspore::MsLogLevel::INFO; namespace mindspore { namespace mindrecord { -ShardSegment::ShardSegment() { set_all_in_index(false); } +ShardSegment::ShardSegment() { SetAllInIndex(false); } std::pair> ShardSegment::GetCategoryFields() { // Skip if already populated @@ -211,7 +211,7 @@ std::pair> ShardSegment::PackImages(int group_id // Pack image list std::vector images(offset[1] - offset[0]); - auto file_offset = header_size_ + page_size_ * (blob_page->get_page_id()) + offset[0]; + auto file_offset = header_size_ + page_size_ * (blob_page->GetPageID()) + offset[0]; auto &io_seekg = file_streams_random_[0][shard_id]->seekg(file_offset, std::ios::beg); if (!io_seekg.good() || io_seekg.fail() || io_seekg.bad()) { MS_LOG(ERROR) << "File seekg failed"; @@ -363,21 +363,21 @@ std::pair, pybind11::obje return {SUCCESS, std::move(json_data)}; } -std::pair> ShardSegment::get_blob_fields() { +std::pair> ShardSegment::GetBlobFields() { std::vector blob_fields; - for (auto &p : get_shard_header()->get_schemas()) { + for (auto &p : GetShardHeader()->GetSchemas()) { // assume one schema - const auto &fields = p->get_blob_fields(); + const auto &fields = p->GetBlobFields(); blob_fields.assign(fields.begin(), fields.end()); break; } - return std::make_pair(get_nlp_flag() ? kNLP : kCV, blob_fields); + return std::make_pair(GetNlpFlag() ? kNLP : kCV, blob_fields); } std::tuple, json> ShardSegment::GetImageLabel(std::vector images, json label) { - if (get_nlp_flag()) { + if (GetNlpFlag()) { vector columns; - for (auto &p : get_shard_header()->get_schemas()) { + for (auto &p : GetShardHeader()->GetSchemas()) { auto schema = p->GetSchema()["schema"]; // make sure schema is not reference since error occurred in arm. auto schema_items = schema.items(); using it_type = decltype(schema_items.begin()); diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index 2fb5db5503..4a33bfddb3 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -179,12 +179,12 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { return FAILED; } shard_header_ = std::make_shared(sh); - auto paths = shard_header_->get_shard_addresses(); - MSRStatus ret = set_header_size(shard_header_->get_header_size()); + auto paths = shard_header_->GetShardAddresses(); + MSRStatus ret = SetHeaderSize(shard_header_->GetHeaderSize()); if (ret == FAILED) { return FAILED; } - ret = set_page_size(shard_header_->get_page_size()); + ret = SetPageSize(shard_header_->GetPageSize()); if (ret == FAILED) { return FAILED; } @@ -229,10 +229,10 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) } // set fields in mindrecord when empty - std::vector> fields = header_data->get_fields(); + std::vector> fields = header_data->GetFields(); if (fields.empty()) { MS_LOG(DEBUG) << "Missing index fields by user, auto generate index fields."; - std::vector> schemas = header_data->get_schemas(); + std::vector> schemas = header_data->GetSchemas(); for (const auto &schema : schemas) { json jsonSchema = schema->GetSchema()["schema"]; for (const auto &el : jsonSchema.items()) { @@ -241,7 +241,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) (el.value()["type"] == "int64" && el.value().find("shape") == el.value().end()) || (el.value()["type"] == "float32" && el.value().find("shape") == el.value().end()) || (el.value()["type"] == "float64" && el.value().find("shape") == el.value().end())) { - fields.emplace_back(std::make_pair(schema->get_schema_id(), el.key())); + fields.emplace_back(std::make_pair(schema->GetSchemaID(), el.key())); } } } @@ -256,12 +256,12 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr header_data) } shard_header_ = header_data; - shard_header_->set_header_size(header_size_); - shard_header_->set_page_size(page_size_); + shard_header_->SetHeaderSize(header_size_); + shard_header_->SetPageSize(page_size_); return SUCCESS; } -MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) { +MSRStatus ShardWriter::SetHeaderSize(const uint64_t &header_size) { // header_size [16KB, 128MB] if (header_size < kMinHeaderSize || header_size > kMaxHeaderSize) { MS_LOG(ERROR) << "Header size should between 16KB and 128MB."; @@ -276,7 +276,7 @@ MSRStatus ShardWriter::set_header_size(const uint64_t &header_size) { return SUCCESS; } -MSRStatus ShardWriter::set_page_size(const uint64_t &page_size) { +MSRStatus ShardWriter::SetPageSize(const uint64_t &page_size) { // PageSize [32KB, 256MB] if (page_size < kMinPageSize || page_size > kMaxPageSize) { MS_LOG(ERROR) << "Page size should between 16KB and 256MB."; @@ -398,7 +398,7 @@ MSRStatus ShardWriter::CheckData(const std::map> &ra return FAILED; } json schema = result.first->GetSchema()["schema"]; - for (const auto &field : result.first->get_blob_fields()) { + for (const auto &field : result.first->GetBlobFields()) { (void)schema.erase(field); } std::vector sub_raw_data = rawdata_iter->second; @@ -456,7 +456,7 @@ std::tuple ShardWriter::ValidateRawData(std::mapget_schemas().size() != schema_count_) { + if (shard_header_->GetSchemas().size() != schema_count_) { MS_LOG(ERROR) << "Data size is not equal with the schema size"; return failed; } @@ -475,9 +475,9 @@ std::tuple ShardWriter::ValidateRawData(std::mapfirst); } - const std::vector> &schemas = shard_header_->get_schemas(); + const std::vector> &schemas = shard_header_->GetSchemas(); if (std::any_of(schemas.begin(), schemas.end(), [schema_ids](const std::shared_ptr &schema) { - return schema_ids.find(schema->get_schema_id()) == schema_ids.end(); + return schema_ids.find(schema->GetSchemaID()) == schema_ids.end(); })) { // There is not enough data which is not matching the number of schema MS_LOG(ERROR) << "Input rawdata schema id do not match real schema id."; @@ -810,10 +810,10 @@ MSRStatus ShardWriter::CutRowGroup(int start_row, int end_row, const std::vector std::vector> &rows_in_group, const std::shared_ptr &last_raw_page, const std::shared_ptr &last_blob_page) { - auto n_byte_blob = last_blob_page ? last_blob_page->get_page_size() : 0; + auto n_byte_blob = last_blob_page ? last_blob_page->GetPageSize() : 0; - auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0; - auto last_raw_offset = last_raw_page ? last_raw_page->get_last_row_group_id().second : 0; + auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; + auto last_raw_offset = last_raw_page ? last_raw_page->GetLastRowGroupID().second : 0; auto n_byte_raw = last_raw_page_size - last_raw_offset; int page_start_row = start_row; @@ -849,8 +849,8 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vectorget_page_id(); - auto bytes_page = last_blob_page->get_page_size(); + auto page_id = last_blob_page->GetPageID(); + auto bytes_page = last_blob_page->GetPageSize(); auto &io_seekp = file_streams_[shard_id]->seekp(page_size_ * page_id + header_size_ + bytes_page, std::ios::beg); if (!io_seekp.good() || io_seekp.fail() || io_seekp.bad()) { MS_LOG(ERROR) << "File seekp failed"; @@ -862,9 +862,9 @@ MSRStatus ShardWriter::AppendBlobPage(const int &shard_id, const std::vectorset_page_size(bytes_page); - uint64_t end_row = last_blob_page->get_end_row_id() + blob_row.second - blob_row.first; - last_blob_page->set_end_row_id(end_row); + last_blob_page->SetPageSize(bytes_page); + uint64_t end_row = last_blob_page->GetEndRowID() + blob_row.second - blob_row.first; + last_blob_page->SetEndRowID(end_row); (void)shard_header_->SetPage(last_blob_page); return SUCCESS; } @@ -873,8 +873,8 @@ MSRStatus ShardWriter::NewBlobPage(const int &shard_id, const std::vector> &rows_in_group, const std::shared_ptr &last_blob_page) { auto page_id = shard_header_->GetLastPageId(shard_id); - auto page_type_id = last_blob_page ? last_blob_page->get_page_type_id() : -1; - auto current_row = last_blob_page ? last_blob_page->get_end_row_id() : 0; + auto page_type_id = last_blob_page ? last_blob_page->GetPageTypeID() : -1; + auto current_row = last_blob_page ? last_blob_page->GetEndRowID() : 0; // index(0) indicate appendBlobPage for (uint32_t i = 1; i < rows_in_group.size(); ++i) { auto blob_row = rows_in_group[i]; @@ -905,15 +905,15 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector &last_raw_page) { auto blob_row = rows_in_group[0]; if (blob_row.first == blob_row.second) return SUCCESS; - auto last_raw_page_size = last_raw_page ? last_raw_page->get_page_size() : 0; + auto last_raw_page_size = last_raw_page ? last_raw_page->GetPageSize() : 0; if (std::accumulate(raw_data_size_.begin() + blob_row.first, raw_data_size_.begin() + blob_row.second, 0) + last_raw_page_size <= page_size_) { return SUCCESS; } auto page_id = shard_header_->GetLastPageId(shard_id); - auto last_row_group_id_offset = last_raw_page->get_last_row_group_id().second; - auto last_raw_page_id = last_raw_page->get_page_id(); + auto last_row_group_id_offset = last_raw_page->GetLastRowGroupID().second; + auto last_raw_page_id = last_raw_page->GetPageID(); auto shift_size = last_raw_page_size - last_row_group_id_offset; std::vector buf(shift_size); @@ -956,10 +956,10 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vectorSetPage(last_raw_page); // Refresh page info in header - int row_group_id = last_raw_page->get_last_row_group_id().first + 1; + int row_group_id = last_raw_page->GetLastRowGroupID().first + 1; std::vector> row_group_ids; row_group_ids.emplace_back(row_group_id, 0); - int page_type_id = last_raw_page->get_page_id(); + int page_type_id = last_raw_page->GetPageID(); auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, shift_size); (void)shard_header_->AddPage(std::make_shared(page)); @@ -971,7 +971,7 @@ MSRStatus ShardWriter::ShiftRawPage(const int &shard_id, const std::vector> &rows_in_group, std::shared_ptr &last_raw_page, const std::vector> &bin_raw_data) { - int last_row_group_id = last_raw_page ? last_raw_page->get_last_row_group_id().first : -1; + int last_row_group_id = last_raw_page ? last_raw_page->GetLastRowGroupID().first : -1; for (uint32_t i = 0; i < rows_in_group.size(); ++i) { const auto &blob_row = rows_in_group[i]; if (blob_row.first == blob_row.second) continue; @@ -979,7 +979,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vectorget_page_size() + raw_size > page_size_) { + } else if (last_raw_page->GetPageSize() + raw_size > page_size_) { (void)shard_header_->SetPage(last_raw_page); EmptyRawPage(shard_id, last_raw_page); } @@ -994,7 +994,7 @@ MSRStatus ShardWriter::WriteRawPage(const int &shard_id, const std::vector &last_raw_page) { auto row_group_ids = std::vector>(); auto page_id = shard_header_->GetLastPageId(shard_id); - auto page_type_id = last_raw_page ? last_raw_page->get_page_id() : -1; + auto page_type_id = last_raw_page ? last_raw_page->GetPageID() : -1; auto page = Page(++page_id, shard_id, kPageTypeRaw, ++page_type_id, 0, 0, row_group_ids, 0); (void)shard_header_->AddPage(std::make_shared(page)); SetLastRawPage(shard_id, last_raw_page); @@ -1003,9 +1003,9 @@ void ShardWriter::EmptyRawPage(const int &shard_id, std::shared_ptr &last_ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vector> &rows_in_group, const int &chunk_id, int &last_row_group_id, std::shared_ptr last_raw_page, const std::vector> &bin_raw_data) { - std::vector> row_group_ids = last_raw_page->get_row_group_ids(); - auto last_raw_page_id = last_raw_page->get_page_id(); - auto n_bytes = last_raw_page->get_page_size(); + std::vector> row_group_ids = last_raw_page->GetRowGroupIds(); + auto last_raw_page_id = last_raw_page->GetPageID(); + auto n_bytes = last_raw_page->GetPageSize(); // previous raw data page auto &io_seekp = @@ -1022,8 +1022,8 @@ MSRStatus ShardWriter::AppendRawPage(const int &shard_id, const std::vectorset_page_size(n_bytes); - last_raw_page->set_row_group_ids(row_group_ids); + last_raw_page->SetPageSize(n_bytes); + last_raw_page->SetRowGroupIds(row_group_ids); (void)shard_header_->SetPage(last_raw_page); return SUCCESS; diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc index 2a9c2c0966..dfca92a08c 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_category.cc @@ -34,7 +34,7 @@ ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elem num_categories_(num_categories), replacement_(replacement) {} -MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } +MSRStatus ShardCategory::Execute(ShardTask &tasks) { return SUCCESS; } int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { if (dataset_size == 0) return dataset_size; diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 26008e3ca9..8db2c6b7c9 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -343,7 +343,7 @@ std::vector ShardHeader::SerializeHeader() { std::string ShardHeader::SerializeIndexFields() { json j; - auto fields = index_->get_fields(); + auto fields = index_->GetFields(); for (const auto &field : fields) { j.push_back({{"schema_id", field.first}, {"index_field", field.second}}); } @@ -365,7 +365,7 @@ std::vector ShardHeader::SerializePage() { std::string ShardHeader::SerializeStatistics() { json j; for (const auto &stats : statistics_) { - j.emplace_back(stats->get_statistics()); + j.emplace_back(stats->GetStatistics()); } return j.dump(); } @@ -398,8 +398,8 @@ MSRStatus ShardHeader::SetPage(const std::shared_ptr &new_page) { if (new_page == nullptr) { return FAILED; } - int shard_id = new_page->get_shard_id(); - int page_id = new_page->get_page_id(); + int shard_id = new_page->GetShardID(); + int page_id = new_page->GetPageID(); if (shard_id < static_cast(pages_.size()) && page_id < static_cast(pages_[shard_id].size())) { pages_[shard_id][page_id] = new_page; return SUCCESS; @@ -412,8 +412,8 @@ MSRStatus ShardHeader::AddPage(const std::shared_ptr &new_page) { if (new_page == nullptr) { return FAILED; } - int shard_id = new_page->get_shard_id(); - int page_id = new_page->get_page_id(); + int shard_id = new_page->GetShardID(); + int page_id = new_page->GetPageID(); if (shard_id < static_cast(pages_.size()) && page_id == static_cast(pages_[shard_id].size())) { pages_[shard_id].push_back(new_page); return SUCCESS; @@ -435,8 +435,8 @@ int ShardHeader::GetLastPageIdByType(const int &shard_id, const std::string &pag } int last_page_id = -1; for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { - if (pages_[shard_id][i - 1]->get_page_type() == page_type) { - last_page_id = pages_[shard_id][i - 1]->get_page_id(); + if (pages_[shard_id][i - 1]->GetPageType() == page_type) { + last_page_id = pages_[shard_id][i - 1]->GetPageID(); return last_page_id; } } @@ -451,7 +451,7 @@ const std::pair> ShardHeader::GetPageByGroupId( } for (uint64_t i = pages_[shard_id].size(); i >= 1; i--) { auto page = pages_[shard_id][i - 1]; - if (page->get_page_type() == kPageTypeBlob && page->get_page_type_id() == group_id) { + if (page->GetPageType() == kPageTypeBlob && page->GetPageTypeID() == group_id) { return {SUCCESS, page}; } } @@ -470,10 +470,10 @@ int ShardHeader::AddSchema(std::shared_ptr schema) { return -1; } - int64_t schema_id = schema->get_schema_id(); + int64_t schema_id = schema->GetSchemaID(); if (schema_id == -1) { schema_id = schema_.size(); - schema->set_schema_id(schema_id); + schema->SetSchemaID(schema_id); } schema_.push_back(schema); return schema_id; @@ -481,10 +481,10 @@ int ShardHeader::AddSchema(std::shared_ptr schema) { void ShardHeader::AddStatistic(std::shared_ptr statistic) { if (statistic) { - int64_t statistics_id = statistic->get_statistics_id(); + int64_t statistics_id = statistic->GetStatisticsID(); if (statistics_id == -1) { statistics_id = statistics_.size(); - statistic->set_statistics_id(statistics_id); + statistic->SetStatisticsID(statistics_id); } statistics_.push_back(statistic); } @@ -527,13 +527,13 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector &fields) { return FAILED; } - if (get_schemas().empty()) { + if (GetSchemas().empty()) { MS_LOG(ERROR) << "No schema is set"; return FAILED; } for (const auto &schemaPtr : schema_) { - auto result = GetSchemaByID(schemaPtr->get_schema_id()); + auto result = GetSchemaByID(schemaPtr->GetSchemaID()); if (result.second != SUCCESS) { MS_LOG(ERROR) << "Could not get schema by id."; return FAILED; @@ -548,7 +548,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector &fields) { // checkout and add fields for each schema std::set field_set; - for (const auto &item : index->get_fields()) { + for (const auto &item : index->GetFields()) { field_set.insert(item.second); } for (const auto &field : fields) { @@ -564,7 +564,7 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector &fields) { field_set.insert(field); // add field into index - index.get()->AddIndexField(schemaPtr->get_schema_id(), field); + index.get()->AddIndexField(schemaPtr->GetSchemaID(), field); } } @@ -575,12 +575,12 @@ MSRStatus ShardHeader::AddIndexFields(const std::vector &fields) { MSRStatus ShardHeader::GetAllSchemaID(std::set &bucket_count) { // get all schema id for (const auto &schema : schema_) { - auto bucket_it = bucket_count.find(schema->get_schema_id()); + auto bucket_it = bucket_count.find(schema->GetSchemaID()); if (bucket_it != bucket_count.end()) { MS_LOG(ERROR) << "Schema duplication"; return FAILED; } else { - bucket_count.insert(schema->get_schema_id()); + bucket_count.insert(schema->GetSchemaID()); } } return SUCCESS; @@ -603,7 +603,7 @@ MSRStatus ShardHeader::AddIndexFields(std::vector> field_set; - for (const auto &item : index->get_fields()) { + for (const auto &item : index->GetFields()) { field_set.insert(item); } for (const auto &field : fields) { @@ -646,20 +646,20 @@ MSRStatus ShardHeader::AddIndexFields(std::vector= shard_addresses_.size()) { return ""; } return shard_addresses_.at(shard_id); } -std::vector> ShardHeader::get_schemas() { return schema_; } +std::vector> ShardHeader::GetSchemas() { return schema_; } -std::vector> ShardHeader::get_statistics() { return statistics_; } +std::vector> ShardHeader::GetStatistics() { return statistics_; } -std::vector> ShardHeader::get_fields() { return index_->get_fields(); } +std::vector> ShardHeader::GetFields() { return index_->GetFields(); } -std::shared_ptr ShardHeader::get_index() { return index_; } +std::shared_ptr ShardHeader::GetIndex() { return index_; } std::pair, MSRStatus> ShardHeader::GetSchemaByID(int64_t schema_id) { int64_t schemaSize = schema_.size(); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_index.cc b/mindspore/ccsrc/mindrecord/meta/shard_index.cc index ddb85fc66e..8b7a3c0342 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_index.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_index.cc @@ -28,6 +28,6 @@ void Index::AddIndexField(const int64_t &schemaId, const std::string &field) { } // Get attribute list -std::vector> Index::get_fields() { return fields_; } +std::vector> Index::GetFields() { return fields_; } } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc index 8e2e892e63..fac2fec708 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc @@ -34,7 +34,7 @@ ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elem shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement } -MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) { +MSRStatus ShardPkSample::SufExecute(ShardTask &tasks) { if (shuffle_ == true) { if (SUCCESS != (*shuffle_op_)(tasks)) { return FAILED; diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc index a9cfce0d01..d7842a11a3 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc @@ -74,14 +74,14 @@ int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } -const std::pair ShardSample::get_partitions() const { +const std::pair ShardSample::GetPartitions() const { if (numerator_ == 1 && denominator_ > 1) { return std::pair(denominator_, partition_id_); } return std::pair(-1, -1); } -MSRStatus ShardSample::execute(ShardTask &tasks) { +MSRStatus ShardSample::Execute(ShardTask &tasks) { int no_of_categories = static_cast(tasks.categories); int total_no = static_cast(tasks.Size()); @@ -114,11 +114,11 @@ MSRStatus ShardSample::execute(ShardTask &tasks) { if (sampler_type_ == kSubsetRandomSampler) { for (int i = 0; i < indices_.size(); ++i) { int index = ((indices_[i] % total_no) + total_no) % total_no; - new_tasks.InsertTask(tasks.get_task_by_id(index)); // different mod result between c and python + new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python } } else { for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { - new_tasks.InsertTask(tasks.get_task_by_id(i % total_no)); // rounding up. if overflow, go back to start + new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start } } std::swap(tasks, new_tasks); @@ -129,14 +129,14 @@ MSRStatus ShardSample::execute(ShardTask &tasks) { } total_no = static_cast(tasks.permutation_.size()); for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { - new_tasks.InsertTask(tasks.get_task_by_id(tasks.permutation_[i % total_no])); + new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); } std::swap(tasks, new_tasks); } return SUCCESS; } -MSRStatus ShardSample::suf_execute(ShardTask &tasks) { +MSRStatus ShardSample::SufExecute(ShardTask &tasks) { if (sampler_type_ == kSubsetRandomSampler) { if (SUCCESS != (*shuffle_op_)(tasks)) { return FAILED; diff --git a/mindspore/ccsrc/mindrecord/meta/shard_schema.cc b/mindspore/ccsrc/mindrecord/meta/shard_schema.cc index 0c2550e2dc..ee0f5afa4a 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_schema.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_schema.cc @@ -44,7 +44,7 @@ std::shared_ptr Schema::Build(std::string desc, pybind11::handle schema) return Build(std::move(desc), schema_json); } -std::string Schema::get_desc() const { return desc_; } +std::string Schema::GetDesc() const { return desc_; } json Schema::GetSchema() const { json str_schema; @@ -60,11 +60,11 @@ pybind11::object Schema::GetSchemaForPython() const { return schema_py; } -void Schema::set_schema_id(int64_t id) { schema_id_ = id; } +void Schema::SetSchemaID(int64_t id) { schema_id_ = id; } -int64_t Schema::get_schema_id() const { return schema_id_; } +int64_t Schema::GetSchemaID() const { return schema_id_; } -std::vector Schema::get_blob_fields() const { return blob_fields_; } +std::vector Schema::GetBlobFields() const { return blob_fields_; } std::vector Schema::PopulateBlobFields(json schema) { std::vector blob_fields; @@ -155,7 +155,7 @@ bool Schema::Validate(json schema) { } bool Schema::operator==(const mindrecord::Schema &b) const { - if (this->get_desc() != b.get_desc() || this->GetSchema() != b.GetSchema()) { + if (this->GetDesc() != b.GetDesc() || this->GetSchema() != b.GetSchema()) { return false; } return true; diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc index 757dcb7b74..d33400ef38 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc @@ -23,7 +23,7 @@ namespace mindrecord { ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} -MSRStatus ShardShuffle::execute(ShardTask &tasks) { +MSRStatus ShardShuffle::Execute(ShardTask &tasks) { if (tasks.categories < 1) { return FAILED; } diff --git a/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc b/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc index deaf0b1874..ca36c50863 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_statistics.cc @@ -48,9 +48,9 @@ std::shared_ptr Statistics::Build(std::string desc, pybind11::handle return std::make_shared(object_statistics); } -std::string Statistics::get_desc() const { return desc_; } +std::string Statistics::GetDesc() const { return desc_; } -json Statistics::get_statistics() const { +json Statistics::GetStatistics() const { json str_statistics; str_statistics["desc"] = desc_; str_statistics["statistics"] = statistics_; @@ -58,13 +58,13 @@ json Statistics::get_statistics() const { } pybind11::object Statistics::GetStatisticsForPython() const { - json str_statistics = Statistics::get_statistics(); + json str_statistics = Statistics::GetStatistics(); return nlohmann::detail::FromJsonImpl(str_statistics); } -void Statistics::set_statistics_id(int64_t id) { statistics_id_ = id; } +void Statistics::SetStatisticsID(int64_t id) { statistics_id_ = id; } -int64_t Statistics::get_statistics_id() const { return statistics_id_; } +int64_t Statistics::GetStatisticsID() const { return statistics_id_; } bool Statistics::Validate(const json &statistics) { if (statistics.size() != kInt1) { @@ -103,7 +103,7 @@ bool Statistics::LevelRecursive(json level) { } bool Statistics::operator==(const Statistics &b) const { - if (this->get_statistics() != b.get_statistics()) { + if (this->GetStatistics() != b.GetStatistics()) { return false; } return true; diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc index be566d1601..3abc725a7b 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_task.cc @@ -59,12 +59,12 @@ uint32_t ShardTask::SizeOfRows() const { return nRows; } -std::tuple, std::vector, json> &ShardTask::get_task_by_id(size_t id) { +std::tuple, std::vector, json> &ShardTask::GetTaskByID(size_t id) { MS_ASSERT(id < task_list_.size()); return task_list_[id]; } -std::tuple, std::vector, json> &ShardTask::get_random_task() { +std::tuple, std::vector, json> &ShardTask::GetRandomTask() { std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution<> dis(0, task_list_.size() - 1); @@ -82,7 +82,7 @@ ShardTask ShardTask::Combine(std::vector &category_tasks, bool replac } for (uint32_t task_no = 0; task_no < minTasks; task_no++) { for (uint32_t i = 0; i < total_categories; i++) { - res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast(task_no)))); + res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast(task_no)))); } } } else { @@ -95,7 +95,7 @@ ShardTask ShardTask::Combine(std::vector &category_tasks, bool replac } for (uint32_t i = 0; i < total_categories; i++) { for (uint32_t j = 0; j < maxTasks; j++) { - res.InsertTask(category_tasks[i].get_random_task()); + res.InsertTask(category_tasks[i].GetRandomTask()); } } } diff --git a/tests/ut/cpp/mindrecord/ut_shard.cc b/tests/ut/cpp/mindrecord/ut_shard.cc index 994ff1b859..b8c229e82f 100644 --- a/tests/ut/cpp/mindrecord/ut_shard.cc +++ b/tests/ut/cpp/mindrecord/ut_shard.cc @@ -52,7 +52,7 @@ TEST_F(TestShard, TestShardSchemaPart) { std::shared_ptr schema = Schema::Build(desc, j); ASSERT_TRUE(schema != nullptr); - MS_LOG(INFO) << "schema description: " << schema->get_desc() << ", schema: " << + MS_LOG(INFO) << "schema description: " << schema->GetDesc() << ", schema: " << common::SafeCStr(schema->GetSchema().dump()); for (int i = 1; i <= 4; i++) { string filename = std::string("./imagenet.shard0") + std::to_string(i); @@ -71,8 +71,8 @@ TEST_F(TestShard, TestStatisticPart) { nlohmann::json statistic_json = json::parse(kStatistics[2]); std::shared_ptr statistics = Statistics::Build(desc, statistic_json); ASSERT_TRUE(statistics != nullptr); - MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc(); - MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump(); + MS_LOG(INFO) << "test get_desc(), result: " << statistics->GetDesc(); + MS_LOG(INFO) << "test get_statistics, result: " << statistics->GetStatistics().dump(); std::string desc2 = "axis"; nlohmann::json statistic_json2 = R"({})"; @@ -111,13 +111,13 @@ TEST_F(TestShard, TestShardHeaderPart) { ASSERT_EQ(res, 0); header_data.AddStatistic(statistics1); std::vector re_schemas; - for (auto &schema_ptr : header_data.get_schemas()) { + for (auto &schema_ptr : header_data.GetSchemas()) { re_schemas.push_back(*schema_ptr); } ASSERT_EQ(re_schemas, validate_schema); std::vector re_statistics; - for (auto &statistic : header_data.get_statistics()) { + for (auto &statistic : header_data.GetStatistics()) { re_statistics.push_back(*statistic); } ASSERT_EQ(re_statistics, validate_statistics); @@ -129,7 +129,7 @@ TEST_F(TestShard, TestShardHeaderPart) { std::pair pair1(0, "name"); fields.push_back(pair1); ASSERT_TRUE(header_data.AddIndexFields(fields) == SUCCESS); - std::vector> resFields = header_data.get_fields(); + std::vector> resFields = header_data.GetFields(); ASSERT_EQ(resFields, fields); } diff --git a/tests/ut/cpp/mindrecord/ut_shard_header_test.cc b/tests/ut/cpp/mindrecord/ut_shard_header_test.cc index ce5f40c10c..cea71c34b7 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_header_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_header_test.cc @@ -70,7 +70,7 @@ TEST_F(TestShardHeader, AddIndexFields) { int schema_id1 = header_data.AddSchema(schema1); int schema_id2 = header_data.AddSchema(schema2); ASSERT_EQ(schema_id2, -1); - ASSERT_EQ(header_data.get_schemas().size(), 1); + ASSERT_EQ(header_data.GetSchemas().size(), 1); // check out fields std::vector> fields; @@ -81,35 +81,35 @@ TEST_F(TestShardHeader, AddIndexFields) { fields.push_back(index_field2); MSRStatus res = header_data.AddIndexFields(fields); ASSERT_EQ(res, SUCCESS); - ASSERT_EQ(header_data.get_fields().size(), 2); + ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field3(schema_id1, "name"); fields.push_back(index_field3); res = header_data.AddIndexFields(fields); ASSERT_EQ(res, FAILED); - ASSERT_EQ(header_data.get_fields().size(), 2); + ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field4(schema_id1, "names"); fields.push_back(index_field4); res = header_data.AddIndexFields(fields); ASSERT_EQ(res, FAILED); - ASSERT_EQ(header_data.get_fields().size(), 2); + ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field5(schema_id1 + 1, "name"); fields.push_back(index_field5); res = header_data.AddIndexFields(fields); ASSERT_EQ(res, FAILED); - ASSERT_EQ(header_data.get_fields().size(), 2); + ASSERT_EQ(header_data.GetFields().size(), 2); fields.clear(); std::pair index_field6(schema_id1, "label"); fields.push_back(index_field6); res = header_data.AddIndexFields(fields); ASSERT_EQ(res, FAILED); - ASSERT_EQ(header_data.get_fields().size(), 2); + ASSERT_EQ(header_data.GetFields().size(), 2); std::string desc_new = "this is a test1"; json schemaContent_new = R"({"name": {"type": "string"}, @@ -121,7 +121,7 @@ TEST_F(TestShardHeader, AddIndexFields) { mindrecord::ShardHeader header_data_new; header_data_new.AddSchema(schema_new); - ASSERT_EQ(header_data_new.get_schemas().size(), 1); + ASSERT_EQ(header_data_new.GetSchemas().size(), 1); // test add fields std::vector single_fields; @@ -131,25 +131,25 @@ TEST_F(TestShardHeader, AddIndexFields) { single_fields.push_back("box"); res = header_data_new.AddIndexFields(single_fields); ASSERT_EQ(res, FAILED); - ASSERT_EQ(header_data_new.get_fields().size(), 1); + ASSERT_EQ(header_data_new.GetFields().size(), 1); single_fields.push_back("name"); single_fields.push_back("box"); res = header_data_new.AddIndexFields(single_fields); ASSERT_EQ(res, FAILED); - ASSERT_EQ(header_data_new.get_fields().size(), 1); + ASSERT_EQ(header_data_new.GetFields().size(), 1); single_fields.clear(); single_fields.push_back("names"); res = header_data_new.AddIndexFields(single_fields); ASSERT_EQ(res, FAILED); - ASSERT_EQ(header_data_new.get_fields().size(), 1); + ASSERT_EQ(header_data_new.GetFields().size(), 1); single_fields.clear(); single_fields.push_back("box"); res = header_data_new.AddIndexFields(single_fields); ASSERT_EQ(res, SUCCESS); - ASSERT_EQ(header_data_new.get_fields().size(), 2); + ASSERT_EQ(header_data_new.GetFields().size(), 2); } } // namespace mindrecord } // namespace mindspore diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 9c177d7a40..6fc5ccbbe5 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -139,7 +139,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { const int kPar = 2; std::vector> ops; ops.push_back(std::make_shared(kNum, kDen, kPar)); - auto partitions = std::dynamic_pointer_cast(ops[0])->get_partitions(); + auto partitions = std::dynamic_pointer_cast(ops[0])->GetPartitions(); ASSERT_TRUE(partitions.first == 4); ASSERT_TRUE(partitions.second == 2); diff --git a/tests/ut/cpp/mindrecord/ut_shard_page_test.cc b/tests/ut/cpp/mindrecord/ut_shard_page_test.cc index f06b987a89..dabd3d819f 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_page_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_page_test.cc @@ -57,15 +57,15 @@ TEST_F(TestShardPage, TestBasic) { Page page = Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); - EXPECT_EQ(kGoldenPageId, page.get_page_id()); - EXPECT_EQ(kGoldenShardId, page.get_shard_id()); - EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); - ASSERT_TRUE(kGoldenType == page.get_page_type()); - EXPECT_EQ(kGoldenSize, page.get_page_size()); - EXPECT_EQ(kGoldenStart, page.get_start_row_id()); - EXPECT_EQ(kGoldenEnd, page.get_end_row_id()); - ASSERT_TRUE(std::make_pair(4, kOffset) == page.get_last_row_group_id()); - ASSERT_TRUE(golden_row_group == page.get_row_group_ids()); + EXPECT_EQ(kGoldenPageId, page.GetPageID()); + EXPECT_EQ(kGoldenShardId, page.GetShardID()); + EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); + ASSERT_TRUE(kGoldenType == page.GetPageType()); + EXPECT_EQ(kGoldenSize, page.GetPageSize()); + EXPECT_EQ(kGoldenStart, page.GetStartRowID()); + EXPECT_EQ(kGoldenEnd, page.GetEndRowID()); + ASSERT_TRUE(std::make_pair(4, kOffset) == page.GetLastRowGroupID()); + ASSERT_TRUE(golden_row_group == page.GetRowGroupIds()); } TEST_F(TestShardPage, TestSetter) { @@ -86,43 +86,43 @@ TEST_F(TestShardPage, TestSetter) { Page page = Page(kGoldenPageId, kGoldenShardId, kGoldenType, kGoldenTypeId, kGoldenStart, kGoldenEnd, golden_row_group, kGoldenSize); - EXPECT_EQ(kGoldenPageId, page.get_page_id()); - EXPECT_EQ(kGoldenShardId, page.get_shard_id()); - EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); - ASSERT_TRUE(kGoldenType == page.get_page_type()); - EXPECT_EQ(kGoldenSize, page.get_page_size()); - EXPECT_EQ(kGoldenStart, page.get_start_row_id()); - EXPECT_EQ(kGoldenEnd, page.get_end_row_id()); - ASSERT_TRUE(std::make_pair(4, kOffset1) == page.get_last_row_group_id()); - ASSERT_TRUE(golden_row_group == page.get_row_group_ids()); + EXPECT_EQ(kGoldenPageId, page.GetPageID()); + EXPECT_EQ(kGoldenShardId, page.GetShardID()); + EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); + ASSERT_TRUE(kGoldenType == page.GetPageType()); + EXPECT_EQ(kGoldenSize, page.GetPageSize()); + EXPECT_EQ(kGoldenStart, page.GetStartRowID()); + EXPECT_EQ(kGoldenEnd, page.GetEndRowID()); + ASSERT_TRUE(std::make_pair(4, kOffset1) == page.GetLastRowGroupID()); + ASSERT_TRUE(golden_row_group == page.GetRowGroupIds()); const int kNewEnd = 33; const int kNewSize = 300; std::vector> new_row_group = {{0, 100}, {100, 200}, {200, 3000}}; - page.set_end_row_id(kNewEnd); - page.set_page_size(kNewSize); - page.set_row_group_ids(new_row_group); - EXPECT_EQ(kGoldenPageId, page.get_page_id()); - EXPECT_EQ(kGoldenShardId, page.get_shard_id()); - EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); - ASSERT_TRUE(kGoldenType == page.get_page_type()); - EXPECT_EQ(kNewSize, page.get_page_size()); - EXPECT_EQ(kGoldenStart, page.get_start_row_id()); - EXPECT_EQ(kNewEnd, page.get_end_row_id()); - ASSERT_TRUE(std::make_pair(200, kOffset2) == page.get_last_row_group_id()); - ASSERT_TRUE(new_row_group == page.get_row_group_ids()); + page.SetEndRowID(kNewEnd); + page.SetPageSize(kNewSize); + page.SetRowGroupIds(new_row_group); + EXPECT_EQ(kGoldenPageId, page.GetPageID()); + EXPECT_EQ(kGoldenShardId, page.GetShardID()); + EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); + ASSERT_TRUE(kGoldenType == page.GetPageType()); + EXPECT_EQ(kNewSize, page.GetPageSize()); + EXPECT_EQ(kGoldenStart, page.GetStartRowID()); + EXPECT_EQ(kNewEnd, page.GetEndRowID()); + ASSERT_TRUE(std::make_pair(200, kOffset2) == page.GetLastRowGroupID()); + ASSERT_TRUE(new_row_group == page.GetRowGroupIds()); page.DeleteLastGroupId(); - EXPECT_EQ(kGoldenPageId, page.get_page_id()); - EXPECT_EQ(kGoldenShardId, page.get_shard_id()); - EXPECT_EQ(kGoldenTypeId, page.get_page_type_id()); - ASSERT_TRUE(kGoldenType == page.get_page_type()); - EXPECT_EQ(3000, page.get_page_size()); - EXPECT_EQ(kGoldenStart, page.get_start_row_id()); - EXPECT_EQ(kNewEnd, page.get_end_row_id()); - ASSERT_TRUE(std::make_pair(100, kOffset3) == page.get_last_row_group_id()); + EXPECT_EQ(kGoldenPageId, page.GetPageID()); + EXPECT_EQ(kGoldenShardId, page.GetShardID()); + EXPECT_EQ(kGoldenTypeId, page.GetPageTypeID()); + ASSERT_TRUE(kGoldenType == page.GetPageType()); + EXPECT_EQ(3000, page.GetPageSize()); + EXPECT_EQ(kGoldenStart, page.GetStartRowID()); + EXPECT_EQ(kNewEnd, page.GetEndRowID()); + ASSERT_TRUE(std::make_pair(100, kOffset3) == page.GetLastRowGroupID()); new_row_group.pop_back(); - ASSERT_TRUE(new_row_group == page.get_row_group_ids()); + ASSERT_TRUE(new_row_group == page.GetRowGroupIds()); } TEST_F(TestShardPage, TestJson) { diff --git a/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc b/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc index cce000fd28..8d9654a5ef 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_schema_test.cc @@ -107,15 +107,15 @@ TEST_F(TestShardSchema, TestFunction) { std::shared_ptr schema = Schema::Build(desc, schema_content); ASSERT_NE(schema, nullptr); - ASSERT_EQ(schema->get_desc(), desc); + ASSERT_EQ(schema->GetDesc(), desc); json schema_json = schema->GetSchema(); ASSERT_EQ(schema_json["desc"], desc); ASSERT_EQ(schema_json["schema"], schema_content); - ASSERT_EQ(schema->get_schema_id(), -1); - schema->set_schema_id(2); - ASSERT_EQ(schema->get_schema_id(), 2); + ASSERT_EQ(schema->GetSchemaID(), -1); + schema->SetSchemaID(2); + ASSERT_EQ(schema->GetSchemaID(), 2); } TEST_F(TestStatistics, StatisticPart) { @@ -137,8 +137,8 @@ TEST_F(TestStatistics, StatisticPart) { ASSERT_NE(statistics, nullptr); - MS_LOG(INFO) << "test get_desc(), result: " << statistics->get_desc(); - MS_LOG(INFO) << "test get_statistics, result: " << statistics->get_statistics().dump(); + MS_LOG(INFO) << "test GetDesc(), result: " << statistics->GetDesc(); + MS_LOG(INFO) << "test GetStatistics, result: " << statistics->GetStatistics().dump(); statistic_json["test"] = "test"; statistics = Statistics::Build(desc, statistic_json); diff --git a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc index 3fa248c2e0..71da456e7c 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_writer_test.cc @@ -194,8 +194,8 @@ TEST_F(TestShardWriter, TestShardWriterShiftRawPage) { fw.Open(file_names); uint64_t header_size = 1 << 14; uint64_t page_size = 1 << 15; - fw.set_header_size(header_size); - fw.set_page_size(page_size); + fw.SetHeaderSize(header_size); + fw.SetPageSize(page_size); // set shardHeader fw.SetShardHeader(std::make_shared(header_data)); @@ -331,8 +331,8 @@ TEST_F(TestShardWriter, TestShardWriterTrial) { fw.Open(file_names); uint64_t header_size = 1 << 14; uint64_t page_size = 1 << 17; - fw.set_header_size(header_size); - fw.set_page_size(page_size); + fw.SetHeaderSize(header_size); + fw.SetPageSize(page_size); // set shardHeader fw.SetShardHeader(std::make_shared(header_data)); @@ -466,8 +466,8 @@ TEST_F(TestShardWriter, TestShardWriterTrialNoFields) { fw.Open(file_names); uint64_t header_size = 1 << 14; uint64_t page_size = 1 << 17; - fw.set_header_size(header_size); - fw.set_page_size(page_size); + fw.SetHeaderSize(header_size); + fw.SetPageSize(page_size); // set shardHeader fw.SetShardHeader(std::make_shared(header_data)); @@ -567,8 +567,8 @@ TEST_F(TestShardWriter, DataCheck) { fw.Open(file_names); uint64_t header_size = 1 << 14; uint64_t page_size = 1 << 17; - fw.set_header_size(header_size); - fw.set_page_size(page_size); + fw.SetHeaderSize(header_size); + fw.SetPageSize(page_size); // set shardHeader fw.SetShardHeader(std::make_shared(header_data)); @@ -668,8 +668,8 @@ TEST_F(TestShardWriter, AllRawDataWrong) { fw.Open(file_names); uint64_t header_size = 1 << 14; uint64_t page_size = 1 << 17; - fw.set_header_size(header_size); - fw.set_page_size(page_size); + fw.SetHeaderSize(header_size); + fw.SetPageSize(page_size); // set shardHeader fw.SetShardHeader(std::make_shared(header_data));