| @@ -784,7 +784,7 @@ void bindSamplerOps(py::module *m) { | |||
| (void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample, | |||
| std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler") | |||
| .def(py::init<int64_t, int64_t, bool, uint32_t>()); | |||
| .def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>()); | |||
| (void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>( | |||
| *m, "MindrecordRandomSampler") | |||
| @@ -29,9 +29,10 @@ namespace mindspore { | |||
| namespace mindrecord { | |||
| class ShardDistributedSample : public ShardSample { | |||
| public: | |||
| ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed); | |||
| ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed, | |||
| int no_of_samples = 0); | |||
| ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed); | |||
| ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0); | |||
| void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; } | |||
| @@ -32,7 +32,7 @@ class ShardSample : public ShardOperator { | |||
| ShardSample(int num, int den); | |||
| ShardSample(int num, int den, int par); | |||
| ShardSample(int num, int den, int par, int no_of_samples = 0); | |||
| ShardSample(const std::vector<int64_t> &indices, uint32_t seed); | |||
| @@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR; | |||
| namespace mindspore { | |||
| namespace mindrecord { | |||
| ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, | |||
| uint32_t seed) | |||
| : ShardSample(1, num_shards, shard_id), | |||
| uint32_t seed, int no_of_samples) | |||
| : ShardSample(1, num_shards, shard_id, no_of_samples), | |||
| shuffle_(shuffle), | |||
| no_of_padded_samples_(no_of_padded_samples), | |||
| first_epoch_(true) { | |||
| shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample); | |||
| } | |||
| ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed) | |||
| : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {} | |||
| ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, | |||
| int no_of_samples) | |||
| : ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {} | |||
| int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { | |||
| if (no_of_padded_samples_ <= 0) { | |||
| @@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den) | |||
| indices_({}), | |||
| sampler_type_(kCustomTopPercentSampler) {} | |||
| ShardSample::ShardSample(int num, int den, int par) | |||
| ShardSample::ShardSample(int num, int den, int par, int no_of_samples) | |||
| : numerator_(num), | |||
| denominator_(den), | |||
| partition_id_(par), | |||
| no_of_samples_(0), | |||
| no_of_samples_(no_of_samples), | |||
| indices_({}), | |||
| sampler_type_(kCustomTopPercentSampler) {} | |||
| @@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { | |||
| new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python | |||
| } | |||
| } else { | |||
| int count = 0; | |||
| for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | |||
| if (no_of_samples_ != 0 && count == no_of_samples_) break; | |||
| new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start | |||
| count++; | |||
| } | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| @@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) { | |||
| return FAILED; | |||
| } | |||
| total_no = static_cast<int>(tasks.permutation_.size()); | |||
| int count = 0; | |||
| for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) { | |||
| if (no_of_samples_ != 0 && count == no_of_samples_) break; | |||
| new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no])); | |||
| count++; | |||
| } | |||
| std::swap(tasks, new_tasks); | |||
| } | |||
| @@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler): | |||
| return c_sampler | |||
def create_for_minddataset(self):
    """Build the C++ MindRecord distributed sampler for this Python sampler.

    The binding takes five positional arguments, so ``num_samples`` must always
    be supplied; ``None`` is mapped to 0 before the call.

    Returns:
        The ``cde.MindrecordDistributedSampler`` with this sampler's child attached.
    """
    num_samples = self.num_samples if self.num_samples is not None else 0
    c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
                                                 self.seed, num_samples)
    c_child_sampler = self.create_child_for_minddataset()
    c_sampler.add_child(c_child_sampler)
    return c_sampler
| @@ -238,6 +238,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file): | |||
| assert partitions(5) == 2 | |||
| assert partitions(9) == 2 | |||
def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
    """Sharded MindDataset read with num_samples=1: check the per-shard row count."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        """Iterate every shard and return the row count the shards produced."""
        counts = []
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=1)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
            counts.append(num_iter)
        # Only one shard's count used to reach the asserts below; require all
        # shards to agree so a regression on any shard is caught.
        assert len(set(counts)) == 1, "row count differs between shards: {}".format(counts)
        return counts[-1]

    assert partitions(4) == 1
    assert partitions(5) == 1
    assert partitions(9) == 1
def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
    """Sharded MindDataset read with num_samples=2: check the per-shard row count."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        """Iterate every shard and return the row count the shards produced."""
        counts = []
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=2)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
            counts.append(num_iter)
        # Only one shard's count used to reach the asserts below; require all
        # shards to agree so a regression on any shard is caught.
        assert len(set(counts)) == 1, "row count differs between shards: {}".format(counts)
        return counts[-1]

    assert partitions(4) == 2
    assert partitions(5) == 2
    assert partitions(9) == 2
def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
    """Sharded MindDataset read with num_samples=3: check the per-shard row count."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        """Iterate every shard and return the row count the shards produced."""
        counts = []
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=3)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
            counts.append(num_iter)
        # Only one shard's count used to reach the asserts below; require all
        # shards to agree so a regression on any shard is caught.
        assert len(set(counts)) == 1, "row count differs between shards: {}".format(counts)
        return counts[-1]

    assert partitions(4) == 3
    assert partitions(5) == 2
    assert partitions(9) == 2
| def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file): | |||
| """tutorial for cv minddataset.""" | |||
| @@ -228,3 +228,24 @@ def test_minddataset_shard_id_bigger_than_num_shard(): | |||
| os.remove(CV_FILE_NAME) | |||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||
def test_cv_minddataset_partition_num_samples_equals_0():
    """A sharded MindDataset created with num_samples=0 must raise a descriptive error."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4

    def partitions(num_shards):
        """Build and drain every shard; expected to raise before yielding rows."""
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=0)
            for _ in data_set.create_dict_iterator():
                pass

    try:
        with pytest.raises(Exception) as error_info:
            partitions(5)
        # Match against the exception itself, not the ExceptionInfo wrapper,
        # whose string form is a pytest implementation detail.
        assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value)
    finally:
        # Always remove the generated files so a failure here cannot poison later tests.
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))