@@ -48,12 +48,12 @@ PYBIND_REGISTER(
   ShardPkSample, 1, ([](const py::module *m) {
     (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
       *m, "MindrecordPkSampler")
-      .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
+      .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle, int64_t num_samples) {
         if (shuffle == true) {
           return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
-                                                             GetSeed());
+                                                             GetSeed(), num_samples);
         } else {
-          return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
+          return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, num_samples);
         }
       }));
   }));

@@ -29,19 +29,23 @@ namespace mindspore {
 namespace mindrecord {
 class ShardPkSample : public ShardCategory {
  public:
-  ShardPkSample(const std::string &category_field, int64_t num_elements);
+  ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples);
-  ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories);
+  ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, int64_t num_samples);
-  ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed);
+  ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed,
+                int64_t num_samples);
   ~ShardPkSample() override{};
   MSRStatus SufExecute(ShardTask &tasks) override;
+  int64_t GetNumSamples() const { return num_samples_; }
  private:
   bool shuffle_;
   std::shared_ptr<ShardShuffle> shuffle_op_;
+  int64_t num_samples_;
 };
 }  // namespace mindrecord
 }  // namespace mindspore

@@ -49,6 +49,7 @@
 #include "minddata/mindrecord/include/shard_error.h"
 #include "minddata/mindrecord/include/shard_index_generator.h"
 #include "minddata/mindrecord/include/shard_operator.h"
+#include "minddata/mindrecord/include/shard_pk_sample.h"
 #include "minddata/mindrecord/include/shard_reader.h"
 #include "minddata/mindrecord/include/shard_sample.h"
 #include "minddata/mindrecord/include/shard_shuffle.h"

@@ -53,7 +53,8 @@ class ShardTask {
   std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &GetRandomTask();
-  static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements);
+  static ShardTask Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
+                           int64_t num_samples);
   uint32_t categories;

@@ -827,6 +827,12 @@ MSRStatus ShardReader::CountTotalRows(const std::vector<std::string> &file_paths
       std::string category_field = category_op->GetCategoryField();
       auto num_classes = GetNumClasses(category_field);
       num_samples = category_op->GetNumSamples(num_samples, num_classes);
+      if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
+        auto tmp = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
+        if (tmp != 0) {
+          num_samples = std::min(num_samples, tmp);
+        }
+      }
     } else if (std::dynamic_pointer_cast<ShardSample>(op)) {
       if (std::dynamic_pointer_cast<ShardDistributedSample>(op)) {
         auto sampler_op = std::dynamic_pointer_cast<ShardDistributedSample>(op);

@@ -958,6 +964,14 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
   auto category_op = std::dynamic_pointer_cast<ShardCategory>(op);
   auto categories = category_op->GetCategories();
   int64_t num_elements = category_op->GetNumElements();
+  int64_t num_samples = 0;
+  if (std::dynamic_pointer_cast<ShardPkSample>(op)) {
+    num_samples = std::dynamic_pointer_cast<ShardPkSample>(op)->GetNumSamples();
+    if (num_samples < 0) {
+      MS_LOG(ERROR) << "Parameter num_samples is not positive or zero";
+      return FAILED;
+    }
+  }
   if (num_elements <= 0) {
     MS_LOG(ERROR) << "Parameter num_element is not positive";
     return FAILED;

@@ -1006,7 +1020,7 @@ MSRStatus ShardReader::CreateTasksByCategory(const std::vector<std::tuple<int, i
     }
     MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks";
   }
-  tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements);
+  tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements, num_samples);
   if (SUCCESS != (*category_op)(tasks_)) {
     return FAILED;
   }

@@ -22,15 +22,18 @@ using mindspore::MsLogLevel::ERROR;
 namespace mindspore {
 namespace mindrecord {
-ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements)
-    : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true), shuffle_(false) {}
+ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_samples)
+    : ShardCategory(category_field, num_elements, std::numeric_limits<int64_t>::max(), true),
+      shuffle_(false),
+      num_samples_(num_samples) {}
-ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories)
-    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {}
+ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
+                             int64_t num_samples)
+    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false), num_samples_(num_samples) {}
 ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories,
-                             uint32_t seed)
-    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) {
+                             uint32_t seed, int64_t num_samples)
+    : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true), num_samples_(num_samples) {
   shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);  // do shuffle and replacement
 }

@@ -86,7 +86,8 @@ std::tuple<TaskType, std::tuple<int, int>, std::vector<uint64_t>, json> &ShardTa
   return task_list_[dis(gen)];
 }
-ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements) {
+ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replacement, int64_t num_elements,
+                             int64_t num_samples) {
   ShardTask res;
   if (category_tasks.empty()) return res;
   auto total_categories = category_tasks.size();

@@ -96,9 +97,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
     for (uint32_t i = 1; i < total_categories; i++) {
       minTasks = std::min(minTasks, category_tasks[i].Size());
     }
+    int64_t count = 0;
     for (uint32_t task_no = 0; task_no < minTasks; task_no++) {
       for (uint32_t i = 0; i < total_categories; i++) {
+        if (num_samples != 0 && count == num_samples) break;
         res.InsertTask(std::move(category_tasks[i].GetTaskByID(static_cast<int>(task_no))));
+        count++;
       }
     }
   } else {

@@ -109,9 +113,12 @@ ShardTask ShardTask::Combine(std::vector<ShardTask> &category_tasks, bool replac
     if (num_elements != std::numeric_limits<int64_t>::max()) {
       maxTasks = static_cast<decltype(maxTasks)>(num_elements);
     }
+    int64_t count = 0;
     for (uint32_t i = 0; i < total_categories; i++) {
       for (uint32_t j = 0; j < maxTasks; j++) {
+        if (num_samples != 0 && count == num_samples) break;
        res.InsertTask(category_tasks[i].GetRandomTask());
+        count++;
      }
    }
  }

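Note (illustration only, not part of the patch): the capping behaviour that this hunk adds to the no-replacement branch of ShardTask::Combine can be summarised by the following Python sketch; the function name and the list-of-lists task representation are hypothetical.

    def combine_without_replacement(category_tasks, num_samples):
        # Round-robin over categories, taking the same number of tasks from each,
        # and stop adding tasks once num_samples have been collected (0 = no cap).
        result = []
        if not category_tasks:
            return result
        min_tasks = min(len(tasks) for tasks in category_tasks)
        count = 0
        for task_no in range(min_tasks):
            for tasks in category_tasks:
                if num_samples != 0 and count == num_samples:
                    break
                result.append(tasks[task_no])
                count += 1
        return result
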
@@ -359,7 +359,8 @@ class PKSampler(BuiltinSampler):
         if not self.class_column or not isinstance(self.class_column, str):
             raise ValueError("class_column should be a not empty string value, \
                              but got class_column={}".format(class_column))
-        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
+        num_samples = self.num_samples if self.num_samples is not None else 0
+        c_sampler = cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle, num_samples)
         c_child_sampler = self.create_child_for_minddataset()
         c_sampler.add_child(c_child_sampler)
         return c_sampler

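User-visible effect of this hunk, mirroring the new unit tests further down: PKSampler's existing num_samples argument is now forwarded to the MindRecord sampler, so it caps the total number of rows a MindDataset returns. A minimal usage sketch; the file path is a placeholder and the expected size comes from the tests below.

    import mindspore.dataset as ds

    # Up to 3 rows per distinct "label" value, shuffled, but at most 5 rows in total.
    sampler = ds.PKSampler(3, None, True, 'label', 5)
    data_set = ds.MindDataset("/path/to/mindrecord_file0", ["data", "file_name", "label"],
                              4, sampler=sampler)
    print(data_set.get_dataset_size())  # 5 in the corresponding unit test
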
@@ -104,7 +104,7 @@ class TFRecordToMR:
     Args:
         source (str): the TFRecord file to be transformed.
         destination (str): the MindRecord file path to tranform into.
-        feature_dict (dict): a dictionary than states the feature type, i.e.
+        feature_dict (dict): a dictionary that states the feature type, i.e.
             feature_dict = {"xxxx": tf.io.FixedLenFeature([], tf.string), \
                             "yyyy": tf.io.FixedLenFeature([], tf.int64)}

@@ -162,7 +162,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
   auto column_list = std::vector<std::string>{"file_name", "label"};
   std::vector<std::shared_ptr<ShardOperator>> ops;
-  ops.push_back(std::make_shared<ShardPkSample>("label", 2));
+  ops.push_back(std::make_shared<ShardPkSample>("label", 2, 0));
   ShardReader dataset;
   dataset.Open({file_name},true, 4, column_list, ops);

@@ -187,7 +187,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
   auto column_list = std::vector<std::string>{"file_name", "label"};
   std::vector<std::shared_ptr<ShardOperator>> ops;
-  ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0));
+  ops.push_back(std::make_shared<ShardPkSample>("label", 2, 3, 0, 0));
   ShardReader dataset;
   dataset.Open({file_name},true, 4, column_list, ops);

@@ -204,7 +204,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
   }
   dataset.Finish();
   ASSERT_TRUE(i == 6);
-}  // namespace mindrecord
+}
 TEST_F(TestShardOperator, TestShardCategory) {
   MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet"));

@@ -101,7 +101,6 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
             "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
 
-
 def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
     """tutorial for cv minderdataset."""
     columns_list = ["data", "file_name", "label"]

@@ -120,9 +119,51 @@ def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file):
         logger.info(
             "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
+    assert num_iter == 9
+
+
+def test_cv_minddataset_pk_sample_shuffle_1(add_and_remove_cv_file):
+    """tutorial for cv minderdataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.PKSampler(3, None, True, 'label', 5)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 5
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info("-------------- item[file_name]: \
+                    {}------------------------".format(to_str(item["file_name"])))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 5
+
+
+def test_cv_minddataset_pk_sample_shuffle_2(add_and_remove_cv_file):
+    """tutorial for cv minderdataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.PKSampler(3, None, True, 'label', 10)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 9
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info("-------------- item[file_name]: \
+                    {}------------------------".format(to_str(item["file_name"])))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 9
 
 
-def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file):
+def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file):
     """tutorial for cv minderdataset."""
     columns_list = ["data", "file_name", "label"]
     num_readers = 4

@@ -139,6 +180,45 @@ def test_cv_minddataset_pk_sample_out_of_range_0(add_and_remove_cv_file):
         logger.info(
             "-------------- item[label]: {} ----------------------------".format(item["label"]))
         num_iter += 1
+    assert num_iter == 15
+
+
+def test_cv_minddataset_pk_sample_out_of_range_1(add_and_remove_cv_file):
+    """tutorial for cv minderdataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.PKSampler(5, None, True, 'label', 20)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 15
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info("-------------- item[file_name]: \
+                    {}------------------------".format(to_str(item["file_name"])))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 15
+
+
+def test_cv_minddataset_pk_sample_out_of_range_2(add_and_remove_cv_file):
+    """tutorial for cv minderdataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.PKSampler(5, None, True, 'label', 10)
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                              sampler=sampler)
+    assert data_set.get_dataset_size() == 10
+    num_iter = 0
+    for item in data_set.create_dict_iterator():
+        logger.info(
+            "-------------- cv reader basic: {} ------------------------".format(num_iter))
+        logger.info("-------------- item[file_name]: \
+                    {}------------------------".format(to_str(item["file_name"])))
+        logger.info(
+            "-------------- item[label]: {} ----------------------------".format(item["label"]))
+        num_iter += 1
+    assert num_iter == 10
+
+
 def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file):