| @@ -60,6 +60,7 @@ | |||||
| #include "dataset/kernels/data/to_float16_op.h" | #include "dataset/kernels/data/to_float16_op.h" | ||||
| #include "dataset/util/random.h" | #include "dataset/util/random.h" | ||||
| #include "mindrecord/include/shard_operator.h" | #include "mindrecord/include/shard_operator.h" | ||||
| #include "mindrecord/include/shard_pk_sample.h" | |||||
| #include "mindrecord/include/shard_sample.h" | #include "mindrecord/include/shard_sample.h" | ||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| #include "pybind11/stl.h" | #include "pybind11/stl.h" | ||||
| @@ -152,9 +153,14 @@ void bindDatasetOps(py::module *m) { | |||||
| }); | }); | ||||
| (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp") | (void)py::class_<MindRecordOp, DatasetOp, std::shared_ptr<MindRecordOp>>(*m, "MindRecordOp") | ||||
| .def_static("get_num_rows", [](const std::string &path) { | |||||
| .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) { | |||||
| int64_t count = 0; | int64_t count = 0; | ||||
| THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count)); | |||||
| std::shared_ptr<mindrecord::ShardOperator> op; | |||||
| if (py::hasattr(sampler, "_create_for_minddataset")) { | |||||
| auto create = sampler.attr("_create_for_minddataset"); | |||||
| op = create().cast<std::shared_ptr<mindrecord::ShardOperator>>(); | |||||
| } | |||||
| THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count)); | |||||
| return count; | return count; | ||||
| }); | }); | ||||
| @@ -435,6 +441,16 @@ void bindSamplerOps(py::module *m) { | |||||
| (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>( | (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>( | ||||
| *m, "MindrecordSubsetRandomSampler") | *m, "MindrecordSubsetRandomSampler") | ||||
| .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); | .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); | ||||
| (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>( | |||||
| *m, "MindrecordPkSampler") | |||||
| .def(py::init([](int64_t kVal, bool shuffle) { | |||||
| if (shuffle == true) { | |||||
| return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(), | |||||
| GetSeed()); | |||||
| } else { | |||||
| return std::make_shared<mindrecord::ShardPkSample>("label", kVal); | |||||
| } | |||||
| })); | |||||
| (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | ||||
| .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), | .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), | ||||
| @@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status MindRecordOp::CountTotalRows(const std::string dataset_path, int64_t *count) { | |||||
| Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op, | |||||
| int64_t *count) { | |||||
| std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); | std::unique_ptr<ShardReader> shard_reader = std::make_unique<ShardReader>(); | ||||
| MSRStatus rc = shard_reader->CountTotalRows(dataset_path, count); | |||||
| MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count); | |||||
| if (rc == MSRStatus::FAILED) { | if (rc == MSRStatus::FAILED) { | ||||
| RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); | RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); | ||||
| } | } | ||||
| @@ -171,7 +171,8 @@ class MindRecordOp : public ParallelOp { | |||||
| int32_t num_rows() const { return num_rows_; } | int32_t num_rows() const { return num_rows_; } | ||||
| // Getter method | // Getter method | ||||
| static Status CountTotalRows(const std::string dataset_path, int64_t *count); | |||||
| static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr<ShardOperator> &op, | |||||
| int64_t *count); | |||||
| // Getter method | // Getter method | ||||
| int32_t rows_per_buffer() const { return rows_per_buffer_; } | int32_t rows_per_buffer() const { return rows_per_buffer_; } | ||||
| @@ -72,6 +72,8 @@ enum ShardType { | |||||
| enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; | enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; | ||||
| enum ShuffleType { kShuffleCategory, kShuffleSample }; | |||||
| const double kEpsilon = 1e-7; | const double kEpsilon = 1e-7; | ||||
| const int kThreadNumber = 14; | const int kThreadNumber = 14; | ||||
| @@ -17,6 +17,8 @@ | |||||
| #ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ | #ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ | ||||
| #define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ | #define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ | ||||
| #include <algorithm> | |||||
| #include <limits> | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -26,16 +28,34 @@ namespace mindspore { | |||||
| namespace mindrecord { | namespace mindrecord { | ||||
| class ShardCategory : public ShardOperator { | class ShardCategory : public ShardOperator { | ||||
| public: | public: | ||||
| explicit ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories); | |||||
| explicit ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories, | |||||
| int64_t num_elements = std::numeric_limits<int64_t>::max(), bool replacement = false); | |||||
| ShardCategory(const std::string &category_field, int64_t num_elements, | |||||
| int64_t num_categories = std::numeric_limits<int64_t>::max(), bool replacement = false); | |||||
| ~ShardCategory() override{}; | ~ShardCategory() override{}; | ||||
| const std::vector<std::pair<std::string, std::string>> &get_categories() const; | |||||
| const std::vector<std::pair<std::string, std::string>> &get_categories() const { return categories_; } | |||||
| const std::string GetCategoryField() const { return category_field_; } | |||||
| int64_t GetNumElements() const { return num_elements_; } | |||||
| int64_t GetNumCategories() const { return num_categories_; } | |||||
| bool GetReplacement() const { return replacement_; } | |||||
| MSRStatus execute(ShardTask &tasks) override; | MSRStatus execute(ShardTask &tasks) override; | ||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||||
| private: | private: | ||||
| std::vector<std::pair<std::string, std::string>> categories_; | std::vector<std::pair<std::string, std::string>> categories_; | ||||
| std::string category_field_; | |||||
| int64_t num_elements_; | |||||
| int64_t num_categories_; | |||||
| bool replacement_; | |||||
| }; | }; | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -43,6 +43,8 @@ class ShardOperator { | |||||
| virtual MSRStatus execute(ShardTask &tasks) = 0; | virtual MSRStatus execute(ShardTask &tasks) = 0; | ||||
| virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } | virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } | ||||
| virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } | |||||
| }; | }; | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ | |||||
| #define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "mindrecord/include/shard_operator.h" | |||||
| #include "mindrecord/include/shard_shuffle.h" | |||||
| #include "mindrecord/include/shard_category.h" | |||||
| namespace mindspore { | |||||
| namespace mindrecord { | |||||
| class ShardPkSample : public ShardCategory { | |||||
| public: | |||||
| ShardPkSample(const std::string &category_field, int64_t num_elements); | |||||
| ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); | |||||
| ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); | |||||
| ~ShardPkSample() override{}; | |||||
| MSRStatus suf_execute(ShardTask &tasks) override; | |||||
| private: | |||||
| bool shuffle_; | |||||
| std::shared_ptr<ShardShuffle> shuffle_op_; | |||||
| }; | |||||
| } // namespace mindrecord | |||||
| } // namespace mindspore | |||||
| #endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ | |||||
| @@ -115,9 +115,10 @@ class ShardReader { | |||||
| /// \brief get the number of rows in database | /// \brief get the number of rows in database | ||||
| /// \param[in] file_path the path of ONE file, any file in dataset is fine | /// \param[in] file_path the path of ONE file, any file in dataset is fine | ||||
| /// \param[in] op smart pointer refer to ShardCategory or ShardSample object | |||||
| /// \param[out] count # of rows | /// \param[out] count # of rows | ||||
| /// \return MSRStatus the status of MSRStatus | /// \return MSRStatus the status of MSRStatus | ||||
| MSRStatus CountTotalRows(const std::string &file_path, int64_t *count); | |||||
| MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, int64_t *count); | |||||
| /// \brief shuffle task with incremental seed | /// \brief shuffle task with incremental seed | ||||
| /// \return void | /// \return void | ||||
| @@ -197,6 +198,9 @@ class ShardReader { | |||||
| /// \brief get NLP flag | /// \brief get NLP flag | ||||
| bool get_nlp_flag(); | bool get_nlp_flag(); | ||||
| /// \brief get all classes | |||||
| MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); | |||||
| protected: | protected: | ||||
| /// \brief sqlite call back function | /// \brief sqlite call back function | ||||
| static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); | ||||
| @@ -249,8 +253,8 @@ class ShardReader { | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | const std::vector<std::shared_ptr<ShardOperator>> &operators); | ||||
| /// \brief create category-applied task list | /// \brief create category-applied task list | ||||
| int CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators); | |||||
| MSRStatus CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::shared_ptr<ShardOperator> &op); | |||||
| /// \brief create task list in row-reader mode | /// \brief create task list in row-reader mode | ||||
| MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | MSRStatus CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | ||||
| @@ -284,6 +288,12 @@ class ShardReader { | |||||
| MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); | MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); | ||||
| /// \brief get classes in one shard | |||||
| void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set<std::string> &categories); | |||||
| /// \brief get number of classes | |||||
| int64_t GetNumClasses(const std::string &file_path, const std::string &category_field); | |||||
| protected: | protected: | ||||
| uint64_t header_size_; // header size | uint64_t header_size_; // header size | ||||
| uint64_t page_size_; // page size | uint64_t page_size_; // page size | ||||
| @@ -41,8 +41,11 @@ class ShardSample : public ShardOperator { | |||||
| const std::pair<int, int> get_partitions() const; | const std::pair<int, int> get_partitions() const; | ||||
| MSRStatus execute(ShardTask &tasks) override; | MSRStatus execute(ShardTask &tasks) override; | ||||
| MSRStatus suf_execute(ShardTask &tasks) override; | MSRStatus suf_execute(ShardTask &tasks) override; | ||||
| int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; | |||||
| private: | private: | ||||
| int numerator_; | int numerator_; | ||||
| int denominator_; | int denominator_; | ||||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||||
| namespace mindrecord { | namespace mindrecord { | ||||
| class ShardShuffle : public ShardOperator { | class ShardShuffle : public ShardOperator { | ||||
| public: | public: | ||||
| explicit ShardShuffle(uint32_t seed = 0); | |||||
| explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); | |||||
| ~ShardShuffle() override{}; | ~ShardShuffle() override{}; | ||||
| @@ -32,6 +32,7 @@ class ShardShuffle : public ShardOperator { | |||||
| private: | private: | ||||
| uint32_t shuffle_seed_; | uint32_t shuffle_seed_; | ||||
| ShuffleType shuffle_type_; | |||||
| }; | }; | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -41,7 +41,9 @@ class ShardTask { | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id); | std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_task_by_id(size_t id); | ||||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks); | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &get_random_task(); | |||||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements); | |||||
| uint32_t categories = 1; | uint32_t categories = 1; | ||||
| @@ -315,6 +315,43 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, | |||||
| return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); | return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); | ||||
| } | } | ||||
| MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { | |||||
| auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); | |||||
| if (SUCCESS != ret.first) { | |||||
| return FAILED; | |||||
| } | |||||
| std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; | |||||
| std::vector<std::thread> threads = std::vector<std::thread>(shard_count_); | |||||
| for (int x = 0; x < shard_count_; x++) { | |||||
| threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); | |||||
| } | |||||
| for (int x = 0; x < shard_count_; x++) { | |||||
| threads[x].join(); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, | |||||
| std::set<std::string> &categories) { | |||||
| if (nullptr == db) { | |||||
| return; | |||||
| } | |||||
| std::vector<std::vector<std::string>> columns; | |||||
| char *errmsg = nullptr; | |||||
| int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); | |||||
| if (ret != SQLITE_OK) { | |||||
| sqlite3_free(errmsg); | |||||
| sqlite3_close(db); | |||||
| MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Get" << static_cast<int>(columns.size()) << " records from shard " << shard_id << " index."; | |||||
| for (int i = 0; i < static_cast<int>(columns.size()); ++i) { | |||||
| categories.emplace(columns[i][0]); | |||||
| } | |||||
| } | |||||
| ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) { | ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector<std::string> &columns) { | ||||
| std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; | std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; | ||||
| std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{}); | std::vector<std::vector<std::vector<uint64_t>>> offsets(shard_count_, std::vector<std::vector<uint64_t>>{}); | ||||
| @@ -667,11 +704,64 @@ MSRStatus ShardReader::Finish() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| MSRStatus ShardReader::CountTotalRows(const std::string &file_path, int64_t *count) { | |||||
| int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) { | |||||
| ShardHeader sh = ShardHeader(); | |||||
| if (sh.Build(file_path) == FAILED) { | |||||
| return -1; | |||||
| } | |||||
| auto header = std::make_shared<ShardHeader>(sh); | |||||
| auto file_paths = header->get_shard_addresses(); | |||||
| auto shard_count = file_paths.size(); | |||||
| auto index_fields = header->get_fields(); | |||||
| std::map<std::string, int64_t> map_schema_id_fields; | |||||
| for (auto &field : index_fields) { | |||||
| map_schema_id_fields[field.second] = field.first; | |||||
| } | |||||
| auto ret = | |||||
| ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); | |||||
| if (SUCCESS != ret.first) { | |||||
| return -1; | |||||
| } | |||||
| std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; | |||||
| std::vector<std::thread> threads = std::vector<std::thread>(shard_count); | |||||
| std::set<std::string> categories; | |||||
| for (int x = 0; x < shard_count; x++) { | |||||
| sqlite3 *db = nullptr; | |||||
| int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); | |||||
| if (SQLITE_OK != rc) { | |||||
| MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); | |||||
| return -1; | |||||
| } | |||||
| threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); | |||||
| } | |||||
| for (int x = 0; x < shard_count; x++) { | |||||
| threads[x].join(); | |||||
| } | |||||
| return categories.size(); | |||||
| } | |||||
| MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr<ShardOperator> &op, | |||||
| int64_t *count) { | |||||
| if (Init(file_path) == FAILED) { | if (Init(file_path) == FAILED) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| *count = num_rows_; | |||||
| int64_t num_samples = num_rows_; | |||||
| if (std::dynamic_pointer_cast<ShardCategory>(op)) { | |||||
| auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | |||||
| std::string category_field = category_op->GetCategoryField(); | |||||
| auto num_classes = GetNumClasses(file_path, category_field); | |||||
| num_samples = category_op->GetNumSamples(num_rows_, num_classes); | |||||
| } else if (std::dynamic_pointer_cast<ShardSample>(op)) { | |||||
| num_samples = op->GetNumSamples(num_rows_, 0); | |||||
| } else { | |||||
| } | |||||
| if (-1 == num_samples) { | |||||
| MS_LOG(ERROR) << "Failed to get dataset size."; | |||||
| return FAILED; | |||||
| } | |||||
| *count = num_samples; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -793,6 +883,8 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { | |||||
| thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); | thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); | ||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "Launch read thread successfully."; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -828,44 +920,67 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector<std::tuple<int, int, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| int ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators) { | |||||
| MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | |||||
| const std::shared_ptr<ShardOperator> &op) { | |||||
| vector<std::string> columns = GetAllColumns(); | vector<std::string> columns = GetAllColumns(); | ||||
| CheckIfColumnInIndex(columns); | CheckIfColumnInIndex(columns); | ||||
| int category_operator = -1; | |||||
| for (uint32_t i = 0; i < operators.size(); ++i) { | |||||
| const auto &op = operators[i]; | |||||
| if (std::dynamic_pointer_cast<ShardCategory>(op)) category_operator = static_cast<int>(i); | |||||
| auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | |||||
| auto categories = category_op->get_categories(); | |||||
| int64_t num_elements = category_op->GetNumElements(); | |||||
| if (num_elements <= 0) { | |||||
| MS_LOG(ERROR) << "Parameter num_element is not positive"; | |||||
| return FAILED; | |||||
| } | |||||
| if (categories.empty() == true) { | |||||
| std::string category_field = category_op->GetCategoryField(); | |||||
| int64_t num_categories = category_op->GetNumCategories(); | |||||
| if (num_categories <= 0) { | |||||
| MS_LOG(ERROR) << "Parameter num_categories is not positive"; | |||||
| return FAILED; | |||||
| } | |||||
| std::set<std::string> categories_set; | |||||
| auto ret = GetAllClasses(category_field, categories_set); | |||||
| if (SUCCESS != ret) { | |||||
| return FAILED; | |||||
| } | |||||
| int i = 0; | |||||
| for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { | |||||
| categories.emplace_back(category_field, *it); | |||||
| i++; | |||||
| } | |||||
| } | } | ||||
| if (category_operator == -1) return category_operator; | |||||
| auto categories = std::dynamic_pointer_cast<ShardCategory>(operators[category_operator])->get_categories(); | |||||
| // Generate task list, a task will create a batch | // Generate task list, a task will create a batch | ||||
| std::vector<ShardTask> categoryTasks(categories.size()); | std::vector<ShardTask> categoryTasks(categories.size()); | ||||
| for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { | for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { | ||||
| int category_index = 0; | |||||
| for (const auto &rg : row_group_summary) { | for (const auto &rg : row_group_summary) { | ||||
| if (category_index >= num_elements) break; | |||||
| auto shard_id = std::get<0>(rg); | auto shard_id = std::get<0>(rg); | ||||
| auto group_id = std::get<1>(rg); | auto group_id = std::get<1>(rg); | ||||
| auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns); | auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns); | ||||
| if (SUCCESS != std::get<0>(details)) { | if (SUCCESS != std::get<0>(details)) { | ||||
| return -2; | |||||
| return FAILED; | |||||
| } | } | ||||
| auto offsets = std::get<4>(details); | auto offsets = std::get<4>(details); | ||||
| auto number_of_rows = offsets.size(); | auto number_of_rows = offsets.size(); | ||||
| for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { | for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { | ||||
| categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], | |||||
| std::get<5>(details)[iStart]); | |||||
| if (category_index < num_elements) { | |||||
| categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], | |||||
| std::get<5>(details)[iStart]); | |||||
| category_index++; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; | MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; | ||||
| } | } | ||||
| tasks_ = ShardTask::Combine(categoryTasks); | |||||
| return category_operator; | |||||
| tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); | |||||
| if (SUCCESS != (*category_op)(tasks_)) { | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | ||||
| @@ -896,14 +1011,26 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector<std::tuple<int, int, i | |||||
| MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, uint64_t>> &row_group_summary, | ||||
| const std::vector<std::shared_ptr<ShardOperator>> &operators) { | const std::vector<std::shared_ptr<ShardOperator>> &operators) { | ||||
| if (block_reader_) { | if (block_reader_) { | ||||
| CreateTasksByBlock(row_group_summary, operators); | |||||
| if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { | |||||
| return FAILED; | |||||
| } | |||||
| } else { | } else { | ||||
| int category_operator = CreateTasksByCategory(row_group_summary, operators); | |||||
| if (category_operator == -1) { | |||||
| CreateTasksByRow(row_group_summary, operators); | |||||
| int category_operator = -1; | |||||
| for (uint32_t i = 0; i < operators.size(); ++i) { | |||||
| const auto &op = operators[i]; | |||||
| if (std::dynamic_pointer_cast<ShardCategory>(op)) { | |||||
| category_operator = static_cast<int>(i); | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| if (category_operator == -2) { | |||||
| return FAILED; | |||||
| if (-1 == category_operator) { | |||||
| if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -18,11 +18,30 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories) | |||||
| : categories_(categories) {} | |||||
| ShardCategory::ShardCategory(const std::vector<std::pair<std::string, std::string>> &categories, int64_t num_elements, | |||||
| bool replacement) | |||||
| : categories_(categories), | |||||
| category_field_(""), | |||||
| num_elements_(num_elements), | |||||
| num_categories_(0), | |||||
| replacement_(replacement) {} | |||||
| const std::vector<std::pair<std::string, std::string>> &ShardCategory::get_categories() const { return categories_; } | |||||
| ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories, | |||||
| bool replacement) | |||||
| : categories_({}), | |||||
| category_field_(category_field), | |||||
| num_elements_(num_elements), | |||||
| num_categories_(num_categories), | |||||
| replacement_(replacement) {} | |||||
| MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } | MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } | ||||
| int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||||
| if (dataset_size == 0) return dataset_size; | |||||
| if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) { | |||||
| return std::min(num_categories_, num_classes) * num_elements_; | |||||
| } | |||||
| return -1; | |||||
| } | |||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "mindrecord/include/shard_pk_sample.h" | |||||
| using mindspore::LogStream; | |||||
| using mindspore::ExceptionType::NoExceptionType; | |||||
| using mindspore::MsLogLevel::ERROR; | |||||
| namespace mindspore { | |||||
| namespace mindrecord { | |||||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) | |||||
| : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {} | |||||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) | |||||
| : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} | |||||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, | |||||
| uint32_t seed) | |||||
| : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { | |||||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | |||||
| } | |||||
| MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) { | |||||
| if (shuffle_ == true) { | |||||
| if (SUCCESS != (*shuffle_op_)(tasks)) { | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace mindrecord | |||||
| } // namespace mindspore | |||||
| @@ -56,6 +56,24 @@ ShardSample::ShardSample(const std::vector<int64_t> &indices, uint32_t seed) | |||||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed); | shuffle_op_ = std::make_shared<ShardShuffle>(seed); | ||||
| } | } | ||||
| int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||||
| if (sampler_type_ == kCustomTopNSampler) { | |||||
| return no_of_samples_; | |||||
| } | |||||
| if (sampler_type_ == kCustomTopPercentSampler) { | |||||
| if (dataset_size % denominator_ == 0) { | |||||
| return dataset_size / denominator_ * numerator_; | |||||
| } else { | |||||
| return dataset_size / denominator_ * numerator_ + 1; | |||||
| } | |||||
| } | |||||
| if (sampler_type_ == kSubsetRandomSampler) { | |||||
| return indices_.size(); | |||||
| } | |||||
| return -1; | |||||
| } | |||||
| const std::pair<int, int> ShardSample::get_partitions() const { | const std::pair<int, int> ShardSample::get_partitions() const { | ||||
| if (numerator_ == 1 && denominator_ > 1) { | if (numerator_ == 1 && denominator_ > 1) { | ||||
| return std::pair<int, int>(denominator_, partition_id_); | return std::pair<int, int>(denominator_, partition_id_); | ||||
| @@ -20,25 +20,33 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace mindrecord { | namespace mindrecord { | ||||
| ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} | |||||
| ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) | |||||
| : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} | |||||
| MSRStatus ShardShuffle::execute(ShardTask &tasks) { | MSRStatus ShardShuffle::execute(ShardTask &tasks) { | ||||
| if (tasks.categories < 1) { | if (tasks.categories < 1) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| uint32_t individual_size = tasks.Size() / tasks.categories; | |||||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||||
| for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j); | |||||
| std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); | |||||
| } | |||||
| shuffle_seed_++; | |||||
| tasks.permutation_.clear(); | |||||
| for (uint32_t j = 0; j < individual_size; j++) { | |||||
| if (shuffle_type_ == kShuffleSample) { | |||||
| if (tasks.permutation_.empty() == true) { | |||||
| tasks.MakePerm(); | |||||
| } | |||||
| std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); | |||||
| } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) | |||||
| uint32_t individual_size = tasks.Size() / tasks.categories; | |||||
| std::vector<std::vector<int>> new_permutations(tasks.categories, std::vector<int>(individual_size)); | |||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | for (uint32_t i = 0; i < tasks.categories; i++) { | ||||
| tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i)); | |||||
| for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast<int>(j); | |||||
| std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); | |||||
| } | |||||
| tasks.permutation_.clear(); | |||||
| for (uint32_t j = 0; j < individual_size; j++) { | |||||
| for (uint32_t i = 0; i < tasks.categories; i++) { | |||||
| tasks.permutation_.push_back(new_permutations[i][j] * static_cast<int>(tasks.categories) + static_cast<int>(i)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| shuffle_seed_++; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| @@ -35,8 +35,6 @@ void ShardTask::InsertTask(int shard_id, int group_id, const std::vector<uint64_ | |||||
| MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id | MS_LOG(DEBUG) << "Into insert task, shard_id: " << shard_id << ", group_id: " << group_id | ||||
| << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; | << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; | ||||
| task_list_.emplace_back(std::make_tuple(shard_id, group_id), offset, label); | task_list_.emplace_back(std::make_tuple(shard_id, group_id), offset, label); | ||||
| MS_LOG(DEBUG) << "Out of insert task, shard_id: " << shard_id << ", group_id: " << group_id | |||||
| << ", label: " << label.dump() << ", size of task_list_: " << task_list_.size() << "."; | |||||
| } | } | ||||
| void ShardTask::InsertTask(std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> task) { | void ShardTask::InsertTask(std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> task) { | ||||
| @@ -44,9 +42,6 @@ void ShardTask::InsertTask(std::tuple<std::tuple<int, int>, std::vector<uint64_t | |||||
| << ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() | << ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() | ||||
| << ", size of task_list_: " << task_list_.size() << "."; | << ", size of task_list_: " << task_list_.size() << "."; | ||||
| task_list_.push_back(std::move(task)); | task_list_.push_back(std::move(task)); | ||||
| MS_LOG(DEBUG) << "Out of insert task, shard_id: " << std::get<0>(std::get<0>(task)) | |||||
| << ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() | |||||
| << ", size of task_list_: " << task_list_.size() << "."; | |||||
| } | } | ||||
| void ShardTask::PopBack() { task_list_.pop_back(); } | void ShardTask::PopBack() { task_list_.pop_back(); } | ||||
| @@ -69,18 +64,39 @@ std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_ta | |||||
| return task_list_[id]; | return task_list_[id]; | ||||
| } | } | ||||
| ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks) { | |||||
| std::tuple<std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTask::get_random_task() { | |||||
| std::random_device rd; | |||||
| std::mt19937 gen(rd()); | |||||
| std::uniform_int_distribution<> dis(0, task_list_.size() - 1); | |||||
| return task_list_[dis(gen)]; | |||||
| } | |||||
| ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) { | |||||
| ShardTask res; | ShardTask res; | ||||
| if (category_tasks.empty()) return res; | if (category_tasks.empty()) return res; | ||||
| auto total_categories = category_tasks.size(); | auto total_categories = category_tasks.size(); | ||||
| res.categories = static_cast<uint32_t>(total_categories); | res.categories = static_cast<uint32_t>(total_categories); | ||||
| auto minTasks = category_tasks[0].Size(); | |||||
| for (uint32_t i = 1; i < total_categories; i++) { | |||||
| minTasks = std::min(minTasks, category_tasks[i].Size()); | |||||
| } | |||||
| for (uint32_t task_no = 0; task_no < minTasks; task_no++) { | |||||
| if (replacement == false) { | |||||
| auto minTasks = category_tasks[0].Size(); | |||||
| for (uint32_t i = 1; i < total_categories; i++) { | |||||
| minTasks = std::min(minTasks, category_tasks[i].Size()); | |||||
| } | |||||
| for (uint32_t task_no = 0; task_no < minTasks; task_no++) { | |||||
| for (uint32_t i = 0; i < total_categories; i++) { | |||||
| res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no)))); | |||||
| } | |||||
| } | |||||
| } else { | |||||
| auto maxTasks = category_tasks[0].Size(); | |||||
| for (uint32_t i = 1; i < total_categories; i++) { | |||||
| maxTasks = std::max(maxTasks, category_tasks[i].Size()); | |||||
| } | |||||
| if (num_elements != std::numeric_limits<int64_t>::max()) { | |||||
| maxTasks = static_cast<decltype(maxTasks)>(num_elements); | |||||
| } | |||||
| for (uint32_t i = 0; i < total_categories; i++) { | for (uint32_t i = 0; i < total_categories; i++) { | ||||
| res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast<int>(task_no)))); | |||||
| for (uint32_t j = 0; j < maxTasks; j++) { | |||||
| res.InsertTask(category_tasks[i].get_random_task()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return res; | return res; | ||||
| @@ -1882,7 +1882,8 @@ class MindDataset(SourceDataset): | |||||
| block_reader (bool, optional): Whether read data by block mode (default=False). | block_reader (bool, optional): Whether read data by block mode (default=False). | ||||
| sampler (Sampler, optional): Object used to choose samples from the | sampler (Sampler, optional): Object used to choose samples from the | ||||
| dataset (default=None, sampler is exclusive | dataset (default=None, sampler is exclusive | ||||
| with shuffle and block_reader). Support list: SubsetRandomSampler. | |||||
| with shuffle and block_reader). Support list: SubsetRandomSampler, | |||||
| PKSampler. | |||||
| Raises: | Raises: | ||||
| ValueError: If num_shards is specified but shard_id is None. | ValueError: If num_shards is specified but shard_id is None. | ||||
| @@ -1915,8 +1916,10 @@ class MindDataset(SourceDataset): | |||||
| if block_reader is True: | if block_reader is True: | ||||
| logger.warning("WARN: global shuffle is not used.") | logger.warning("WARN: global shuffle is not used.") | ||||
| if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False: | |||||
| raise ValueError("the sampler is not supported yet.") | |||||
| if sampler is not None: | |||||
| if isinstance(sampler, samplers.SubsetRandomSampler) is False and \ | |||||
| isinstance(sampler, samplers.PKSampler) is False: | |||||
| raise ValueError("the sampler is not supported yet.") | |||||
| # sampler exclusive | # sampler exclusive | ||||
| if block_reader is True and sampler is not None: | if block_reader is True and sampler is not None: | ||||
| @@ -1952,7 +1955,7 @@ class MindDataset(SourceDataset): | |||||
| Number, number of batches. | Number, number of batches. | ||||
| """ | """ | ||||
| num_rows = MindRecordOp.get_num_rows(self.dataset_file) | |||||
| num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler) | |||||
| if self.partitions is not None and self.partitions[0] > 0: | if self.partitions is not None and self.partitions[0] > 0: | ||||
| if num_rows % self.partitions[0] == 0: | if num_rows % self.partitions[0] == 0: | ||||
| num_rows = num_rows // self.partitions[0] | num_rows = num_rows // self.partitions[0] | ||||
| @@ -184,6 +184,8 @@ class PKSampler(BuiltinSampler): | |||||
| def create(self): | def create(self): | ||||
| return cde.PKSampler(self.num_val, self.shuffle) | return cde.PKSampler(self.num_val, self.shuffle) | ||||
| def _create_for_minddataset(self): | |||||
| return cde.MindrecordPkSampler(self.num_val, self.shuffle) | |||||
| class RandomSampler(BuiltinSampler): | class RandomSampler(BuiltinSampler): | ||||
| """ | """ | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "gtest/gtest.h" | #include "gtest/gtest.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "mindrecord/include/shard_category.h" | #include "mindrecord/include/shard_category.h" | ||||
| #include "mindrecord/include/shard_pk_sample.h" | |||||
| #include "mindrecord/include/shard_reader.h" | #include "mindrecord/include/shard_reader.h" | ||||
| #include "mindrecord/include/shard_sample.h" | #include "mindrecord/include/shard_sample.h" | ||||
| #include "mindrecord/include/shard_shuffle.h" | #include "mindrecord/include/shard_shuffle.h" | ||||
| @@ -146,6 +147,57 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { | |||||
| ASSERT_TRUE(i <= 10); | ASSERT_TRUE(i <= 10); | ||||
| } | } | ||||
| TEST_F(TestShardOperator, TestShardPkSamplerBasic) { | |||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); | |||||
| std::string file_name = "./imagenet.shard01"; | |||||
| auto column_list = std::vector<std::string>{"file_name", "label"}; | |||||
| std::vector<std::shared_ptr<ShardOperator>> ops; | |||||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2)); | |||||
| ShardReader dataset; | |||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Launch(); | |||||
| int i = 0; | |||||
| while (true) { | |||||
| auto x = dataset.GetNext(); | |||||
| if (x.empty()) break; | |||||
| std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||||
| i++; | |||||
| } | |||||
| dataset.Finish(); | |||||
| ASSERT_TRUE(i == 20); | |||||
| } // namespace mindrecord | |||||
| TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { | |||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); | |||||
| std::string file_name = "./imagenet.shard01"; | |||||
| auto column_list = std::vector<std::string>{"file_name", "label"}; | |||||
| std::vector<std::shared_ptr<ShardOperator>> ops; | |||||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0)); | |||||
| ShardReader dataset; | |||||
| dataset.Open(file_name, 4, column_list, ops); | |||||
| dataset.Launch(); | |||||
| int i = 0; | |||||
| while (true) { | |||||
| auto x = dataset.GetNext(); | |||||
| if (x.empty()) break; | |||||
| std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||||
| i++; | |||||
| } | |||||
| dataset.Finish(); | |||||
| ASSERT_TRUE(i == 6); | |||||
| } // namespace mindrecord | |||||
| TEST_F(TestShardOperator, TestShardCategory) { | TEST_F(TestShardOperator, TestShardCategory) { | ||||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | ||||
| @@ -0,0 +1,10 @@ | |||||
| image_00001.jpg,164 | |||||
| image_00002.jpg,164 | |||||
| image_00003.jpg,164 | |||||
| image_00004.jpg,599 | |||||
| image_00005.jpg,599 | |||||
| image_00006.jpg,599 | |||||
| image_00007.jpg,13 | |||||
| image_00008.jpg,13 | |||||
| image_00009.jpg,13 | |||||
| image_00010.jpg,13 | |||||
| @@ -46,7 +46,7 @@ def add_and_remove_cv_file(): | |||||
| if os.path.exists("{}.db".format(x)): | if os.path.exists("{}.db".format(x)): | ||||
| os.remove("{}.db".format(x)) | os.remove("{}.db".format(x)) | ||||
| writer = FileWriter(CV_FILE_NAME, FILES_NUM) | writer = FileWriter(CV_FILE_NAME, FILES_NUM) | ||||
| data = get_data(CV_DIR_NAME) | |||||
| data = get_data(CV_DIR_NAME, True) | |||||
| cv_schema_json = {"id": {"type": "int32"}, | cv_schema_json = {"id": {"type": "int32"}, | ||||
| "file_name": {"type": "string"}, | "file_name": {"type": "string"}, | ||||
| "label": {"type": "int32"}, | "label": {"type": "int32"}, | ||||
| @@ -61,6 +61,59 @@ def add_and_remove_cv_file(): | |||||
| os.remove("{}.db".format(x)) | os.remove("{}.db".format(x)) | ||||
def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
    """Read the fixture's MindRecord file through PKSampler(2) and check both row counts."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(2)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 6
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        # Keep the message in one literal: a backslash continuation inside the string would pull
        # the next line's indentation into the logged text.
        logger.info("-------------- item[file_name]: {}------------------------".format(
            "".join([chr(x) for x in item["file_name"]])))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    # The iterator must deliver as many rows as get_dataset_size() reported.
    assert num_iter == 6
def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
    """Read the fixture's MindRecord file through PKSampler(3, None, True) and check both row counts."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(3, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 9
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        # Keep the message in one literal: a backslash continuation inside the string would pull
        # the next line's indentation into the logged text.
        logger.info("-------------- item[file_name]: {}------------------------".format(
            "".join([chr(x) for x in item["file_name"]])))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    # The iterator must deliver as many rows as get_dataset_size() reported.
    assert num_iter == 9
def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
    """Read the fixture's MindRecord file through PKSampler(5, None, True) and check both row counts.

    The sampler annotation file has at most four rows per label, so five samples per label
    is more than any class holds.
    """
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    sampler = ds.PKSampler(5, None, True)
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              sampler=sampler)

    assert data_set.get_dataset_size() == 15
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        # Keep the message in one literal: a backslash continuation inside the string would pull
        # the next line's indentation into the logged text.
        logger.info("-------------- item[file_name]: {}------------------------".format(
            "".join([chr(x) for x in item["file_name"]])))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    # The iterator must deliver as many rows as get_dataset_size() reported.
    assert num_iter == 15
| def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): | def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): | ||||
| """tutorial for cv minderdataset.""" | """tutorial for cv minderdataset.""" | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| @@ -69,8 +122,7 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): | |||||
| sampler = ds.SubsetRandomSampler(indices) | sampler = ds.SubsetRandomSampler(indices) | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | ||||
| sampler=sampler) | sampler=sampler) | ||||
| data = get_data(CV_DIR_NAME) | |||||
| assert data_set.get_dataset_size() == 10 | |||||
| assert data_set.get_dataset_size() == 5 | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info( | logger.info( | ||||
| @@ -93,8 +145,7 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file): | |||||
| sampler = ds.SubsetRandomSampler(indices) | sampler = ds.SubsetRandomSampler(indices) | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | ||||
| sampler=sampler) | sampler=sampler) | ||||
| data = get_data(CV_DIR_NAME) | |||||
| assert data_set.get_dataset_size() == 10 | |||||
| assert data_set.get_dataset_size() == 6 | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info( | logger.info( | ||||
| @@ -117,8 +168,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): | |||||
| sampler = ds.SubsetRandomSampler(indices) | sampler = ds.SubsetRandomSampler(indices) | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | ||||
| sampler=sampler) | sampler=sampler) | ||||
| data = get_data(CV_DIR_NAME) | |||||
| assert data_set.get_dataset_size() == 10 | |||||
| assert data_set.get_dataset_size() == 0 | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info( | logger.info( | ||||
| @@ -133,7 +183,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): | |||||
| assert num_iter == 0 | assert num_iter == 0 | ||||
| def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): | |||||
| def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file): | |||||
| """tutorial for cv minderdataset.""" | """tutorial for cv minderdataset.""" | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||
| num_readers = 4 | num_readers = 4 | ||||
| @@ -141,8 +191,7 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): | |||||
| sampler = ds.SubsetRandomSampler(indices) | sampler = ds.SubsetRandomSampler(indices) | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | ||||
| sampler=sampler) | sampler=sampler) | ||||
| data = get_data(CV_DIR_NAME) | |||||
| assert data_set.get_dataset_size() == 10 | |||||
| assert data_set.get_dataset_size() == 5 | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info( | logger.info( | ||||
| @@ -165,8 +214,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): | |||||
| sampler = ds.SubsetRandomSampler(indices) | sampler = ds.SubsetRandomSampler(indices) | ||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | ||||
| sampler=sampler) | sampler=sampler) | ||||
| data = get_data(CV_DIR_NAME) | |||||
| assert data_set.get_dataset_size() == 10 | |||||
| assert data_set.get_dataset_size() == 5 | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| logger.info( | logger.info( | ||||
| @@ -181,7 +229,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): | |||||
| assert num_iter == 5 | assert num_iter == 5 | ||||
| def get_data(dir_name): | |||||
| def get_data(dir_name, sampler=False): | |||||
| """ | """ | ||||
| usage: get data from imagenet dataset | usage: get data from imagenet dataset | ||||
| params: | params: | ||||
| @@ -191,7 +239,10 @@ def get_data(dir_name): | |||||
| if not os.path.isdir(dir_name): | if not os.path.isdir(dir_name): | ||||
| raise IOError("Directory {} not exists".format(dir_name)) | raise IOError("Directory {} not exists".format(dir_name)) | ||||
| img_dir = os.path.join(dir_name, "images") | img_dir = os.path.join(dir_name, "images") | ||||
| ann_file = os.path.join(dir_name, "annotation.txt") | |||||
| if sampler: | |||||
| ann_file = os.path.join(dir_name, "annotation_sampler.txt") | |||||
| else: | |||||
| ann_file = os.path.join(dir_name, "annotation.txt") | |||||
| with open(ann_file, "r") as file_reader: | with open(ann_file, "r") as file_reader: | ||||
| lines = file_reader.readlines() | lines = file_reader.readlines() | ||||
| @@ -243,7 +243,7 @@ def test_minddataset(add_and_remove_cv_file): | |||||
| assert ds1_json == ds2_json | assert ds1_json == ds2_json | ||||
| data = get_data(CV_DIR_NAME) | data = get_data(CV_DIR_NAME) | ||||
| assert data_set.get_dataset_size() == 10 | |||||
| assert data_set.get_dataset_size() == 5 | |||||
| num_iter = 0 | num_iter = 0 | ||||
| for item in data_set.create_dict_iterator(): | for item in data_set.create_dict_iterator(): | ||||
| num_iter += 1 | num_iter += 1 | ||||