| @@ -48,12 +48,12 @@ PYBIND_REGISTER( | |||
| ShardPkSample, 1, ([](const py::module *m) { | |||
| (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>( | |||
| *m, "MindrecordPkSampler") | |||
| .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) { | |||
| .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) { | |||
| if (shuffle == true) { | |||
| return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(), | |||
| GetSeed()); | |||
| GetSeed(), num_samples); | |||
| } else { | |||
| return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal); | |||
| return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples); | |||
| } | |||
| })); | |||
| })); | |||
| @@ -29,19 +29,23 @@ namespace mindspore { | |||
| namespace mindrecord { | |||
| class ShardPkSample : public ShardCategory { | |||
| public: | |||
| ShardPkSample(const std::string &category_field, int64_t num_elements); | |||
| ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples); | |||
| ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); | |||
| ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, int64_t num_samples); | |||
| ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); | |||
| ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed, | |||
| int64_t num_samples); | |||
| ~ShardPkSample() override{}; | |||
| MSRStatus SufExecute(ShardTask &tasks) override; | |||
| int64_t GetNumSamples() const { return num_samples_; } | |||
| private: | |||
| bool shuffle_; | |||
| std::shared_ptr<ShardShuffle> shuffle_op_; | |||
| int64_t num_samples_; | |||
| }; | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -49,6 +49,7 @@ | |||
| #include "minddata/mindrecord/include/shard_error.h" | |||
| #include "minddata/mindrecord/include/shard_index_generator.h" | |||
| #include "minddata/mindrecord/include/shard_operator.h" | |||
| #include "minddata/mindrecord/include/shard_pk_sample.h" | |||
| #include "minddata/mindrecord/include/shard_reader.h" | |||
| #include "minddata/mindrecord/include/shard_sample.h" | |||
| #include "minddata/mindrecord/include/shard_shuffle.h" | |||
| @@ -53,7 +53,8 @@ class ShardTask { | |||
| std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask(); | |||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements); | |||
| static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements, | |||
| int64_t num_samples); | |||
| uint32_t categories; | |||
| @@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths | |||
| std::string category_field = category_op->GetCategoryField(); | |||
| auto num_classes = GetNumClasses(category_field); | |||
| num_samples = category_op->GetNumSamples(num_samples, num_classes); | |||
| if (std::dynamic_pointer_cast<ShardPkSample>(op)) { | |||
| auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples(); | |||
| if (tmp != 0) { | |||
| num_samples = std::min(num_samples, tmp); | |||
| } | |||
| } | |||
| } else if (std::dynamic_pointer_cast<ShardSample>(op)) { | |||
| if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) { | |||
| auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op); | |||
| @@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i | |||
| auto category_op = std::dynamic_pointer_cast<ShardCategory>(op); | |||
| auto categories = category_op->GetCategories(); | |||
| int64_t num_elements = category_op->GetNumElements(); | |||
| int64_t num_samples = 0; | |||
| if (std::dynamic_pointer_cast<ShardPkSample>(op)) { | |||
| num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples(); | |||
| if (num_samples < 0) { | |||
| MS_LOG(ERROR) << "Parameter num_samples is not positive or zero"; | |||
| return FAILED; | |||
| } | |||
| } | |||
| if (num_elements <= 0) { | |||
| MS_LOG(ERROR) << "Parameter num_element is not positive"; | |||
| return FAILED; | |||
| @@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i | |||
| } | |||
| MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; | |||
| } | |||
| tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); | |||
| tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples); | |||
| if (SUCCESS != (*category_op)(tasks_)) { | |||
| return FAILED; | |||
| } | |||
| @@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) | |||
| : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {} | |||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples) | |||
| : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), | |||
| shuffle_(false), | |||
| num_samples_(num_samples) {} | |||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) | |||
| : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} | |||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, | |||
| int64_t num_samples) | |||
| : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false), num_samples_(num_samples) {} | |||
| ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, | |||
| uint32_t seed) | |||
| : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { | |||
| uint32_t seed, int64_t num_samples) | |||
| : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true), num_samples_(num_samples) { | |||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); // do shuffle and replacement | |||
| } | |||
| @@ -86,7 +86,8 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa | |||
| return task_list_[dis(gen)]; | |||
| } | |||
| ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) { | |||
| ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements, | |||
| int64_t num_samples) { | |||
| ShardTask res; | |||
| if (category_tasks.empty()) return res; | |||
| auto total_categories = category_tasks.size(); | |||
| @@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||
| for (uint32_t i = 1; i < total_categories; i++) { | |||
| minTasks = std::min(minTasks, category_tasks[i].Size()); | |||
| } | |||
| int64_t count = 0; | |||
| for (uint32_t task_no = 0; task_no < minTasks; task_no++) { | |||
| for (uint32_t i = 0; i < total_categories; i++) { | |||
| if (num_samples != 0 && count == num_samples) break; | |||
| res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no)))); | |||
| count++; | |||
| } | |||
| } | |||
| } else { | |||
| @@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac | |||
| if (num_elements != std::numeric_limits<int64_t>::max()) { | |||
| maxTasks = static_cast<decltype(maxTasks)>(num_elements); | |||
| } | |||
| int64_t count = 0; | |||
| for (uint32_t i = 0; i < total_categories; i++) { | |||
| for (uint32_t j = 0; j < maxTasks; j++) { | |||
| if (num_samples != 0 && count == num_samples) break; | |||
| res.InsertTask(category_tasks[i].GetRandomTask()); | |||
| count++; | |||
| } | |||
| } | |||
| } | |||
| @@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler): | |||
| if not self.class_column or not isinstance(self.class_column, str): | |||
| raise ValueError("class_column should be a not empty string value, \ | |||
| but got class_column={}".format(class_column)) | |||
| c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle) | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples) | |||
| c_child_sampler = self.create_child_for_minddataset() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -104,7 +104,7 @@ class TFRecordToMR: | |||
| Args: | |||
| source (str): the TFRecord file to be transformed. | |||
| destination (str): the MindRecord file path to tranform into. | |||
| feature_dict (dict): a dictionary than states the feature type, i.e. | |||
| feature_dict (dict): a dictionary that states the feature type, i.e. | |||
| feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \ | |||
| "yyyy": tf.io.FixedLenFeature([], tf.int64)} | |||
| @@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { | |||
| auto column_list = std::vector<std::string>{"file_name", "label"}; | |||
| std::vector<std::shared_ptr<ShardOperator>> ops; | |||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2)); | |||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2, 0)); | |||
| ShardReader dataset; | |||
| dataset.Open({file_name},true, 4, column_list, ops); | |||
| @@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { | |||
| auto column_list = std::vector<std::string>{"file_name", "label"}; | |||
| std::vector<std::shared_ptr<ShardOperator>> ops; | |||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0)); | |||
| ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0, 0)); | |||
| ShardReader dataset; | |||
| dataset.Open({file_name},true, 4, column_list, ops); | |||
| @@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { | |||
| } | |||
| dataset.Finish(); | |||
| ASSERT_TRUE(i == 6); | |||
| } // namespace mindrecord | |||
| } | |||
| TEST_F(TestShardOperator, TestShardCategory) { | |||
| MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); | |||
| @@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): | |||
| """tutorial for cv minderdataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| @@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 9 | |||
| def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file): | |||
| """tutorial for cv minddataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.PKSampler(3, None, True, 'label', 5) | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 5 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 5 | |||
| def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file): | |||
| """tutorial for cv minddataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.PKSampler(3, None, True, 'label', 10) | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 9 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 9 | |||
| def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): | |||
| def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file): | |||
| """tutorial for cv minderdataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| @@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 15 | |||
| def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file): | |||
| """tutorial for cv minddataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.PKSampler(5, None, True, 'label', 20) | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 15 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 15 | |||
| def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file): | |||
| """tutorial for cv minddataset.""" | |||
| columns_list = ["data", "file_name", "label"] | |||
| num_readers = 4 | |||
| sampler = ds.PKSampler(5, None, True, 'label', 10) | |||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, | |||
| sampler=sampler) | |||
| assert data_set.get_dataset_size() == 10 | |||
| num_iter = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| logger.info( | |||
| "-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||
| logger.info("-------------- item[file_name]: \ | |||
| {}------------------------".format(to_str(item["file_name"]))) | |||
| logger.info( | |||
| "-------------- item[label]: {} ----------------------------".format(item["label"])) | |||
| num_iter += 1 | |||
| assert num_iter == 10 | |||
| def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): | |||