Browse Source

activate num_samples in distributed samplers

tags/v0.6.0-beta
liyong 5 years ago
parent
commit
e2ea1fa0df
8 changed files with 108 additions and 11 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
  2. +3
    -2
      mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
  3. +1
    -1
      mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
  4. +5
    -4
      mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
  5. +8
    -2
      mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
  6. +3
    -1
      mindspore/dataset/engine/samplers.py
  7. +66
    -0
      tests/ut/python/dataset/test_minddataset.py
  8. +21
    -0
      tests/ut/python/dataset/test_minddataset_exception.py

+ 1
- 1
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc View File

@@ -784,7 +784,7 @@ void bindSamplerOps(py::module *m) {

(void)py::class_<mindrecord::ShardDistributedSample, mindrecord::ShardSample,
std::shared_ptr<mindrecord::ShardDistributedSample>>(*m, "MindrecordDistributedSampler")
.def(py::init<int64_t, int64_t, bool, uint32_t>());
.def(py::init<int64_t, int64_t, bool, uint32_t, int64_t>());

(void)py::class_<mindrecord::ShardShuffle, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardShuffle>>(
*m, "MindrecordRandomSampler")


+ 3
- 2
mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h View File

@@ -29,9 +29,10 @@ namespace mindspore {
namespace mindrecord {
class ShardDistributedSample : public ShardSample {
public:
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed);
ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle, uint32_t seed,
int no_of_samples = 0);

ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed);
ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed, int no_of_samples = 0);

void SetNumPaddedSamples(int no_of_padded_samples) { no_of_padded_samples_ = no_of_padded_samples; }



+ 1
- 1
mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h View File

@@ -32,7 +32,7 @@ class ShardSample : public ShardOperator {

ShardSample(int num, int den);

ShardSample(int num, int den, int par);
ShardSample(int num, int den, int par, int no_of_samples = 0);

ShardSample(const std::vector<int64_t> &indices, uint32_t seed);



+ 5
- 4
mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc View File

@@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore {
namespace mindrecord {
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
uint32_t seed)
: ShardSample(1, num_shards, shard_id),
uint32_t seed, int no_of_samples)
: ShardSample(1, num_shards, shard_id, no_of_samples),
shuffle_(shuffle),
no_of_padded_samples_(no_of_padded_samples),
first_epoch_(true) {
shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
}

ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed)
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed) {}
ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, bool shuffle, uint32_t seed,
int no_of_samples)
: ShardDistributedSample(num_shards, shard_id, 0, shuffle, seed, no_of_samples) {}

int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
if (no_of_padded_samples_ <= 0) {


+ 8
- 2
mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc View File

@@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den)
indices_({}),
sampler_type_(kCustomTopPercentSampler) {}

ShardSample::ShardSample(int num, int den, int par)
ShardSample::ShardSample(int num, int den, int par, int no_of_samples)
: numerator_(num),
denominator_(den),
partition_id_(par),
no_of_samples_(0),
no_of_samples_(no_of_samples),
indices_({}),
sampler_type_(kCustomTopPercentSampler) {}

@@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
new_tasks.InsertTask(tasks.GetTaskByID(index)); // different mod result between c and python
}
} else {
int count = 0;
for (int i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(i % total_no)); // rounding up. if overflow, go back to start
count++;
}
}
std::swap(tasks, new_tasks);
@@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
return FAILED;
}
total_no = static_cast<int>(tasks.permutation_.size());
int count = 0;
for (size_t i = partition_id_ * taking; i < (partition_id_ + 1) * taking; i++) {
if (no_of_samples_ != 0 && count == no_of_samples_) break;
new_tasks.InsertTask(tasks.GetTaskByID(tasks.permutation_[i % total_no]));
count++;
}
std::swap(tasks, new_tasks);
}


+ 3
- 1
mindspore/dataset/engine/samplers.py View File

@@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler):
return c_sampler

def create_for_minddataset(self):
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed)
num_samples = self.num_samples if self.num_samples is not None else 0
c_sampler = cde.MindrecordDistributedSampler(self.num_shards, self.shard_id, self.shuffle,
self.seed, num_samples)
c_child_sampler = self.create_child_for_minddataset()
c_sampler.add_child(c_child_sampler)
return c_sampler


+ 66
- 0
tests/ut/python/dataset/test_minddataset.py View File

@@ -238,6 +238,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
assert partitions(5) == 2
assert partitions(9) == 2

def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Read every shard in turn; the helper returns the number of rows
        # seen on the last shard, which the asserts below check against
        # num_samples=1.
        rows_seen = 0
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=1)
            rows_seen = 0
            for rows_seen, item in enumerate(data_set.create_dict_iterator(), start=1):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        return rows_seen

    assert partitions(4) == 1
    assert partitions(5) == 1
    assert partitions(9) == 1

def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Read every shard in turn; the helper returns the number of rows
        # seen on the last shard, which the asserts below check against
        # num_samples=2.
        rows_seen = 0
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=2)
            rows_seen = 0
            for rows_seen, item in enumerate(data_set.create_dict_iterator(), start=1):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        return rows_seen

    assert partitions(4) == 2
    assert partitions(5) == 2
    assert partitions(9) == 2

def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Read every shard in turn; the helper returns the number of rows
        # seen on the last shard. With num_samples=3 the cap only bites
        # when the shard actually holds that many rows, hence the mixed
        # expectations below.
        rows_seen = 0
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=3)
            rows_seen = 0
            for rows_seen, item in enumerate(data_set.create_dict_iterator(), start=1):
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
        return rows_seen

    assert partitions(4) == 3
    assert partitions(5) == 2
    assert partitions(9) == 2


def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
"""tutorial for cv minddataset."""


+ 21
- 0
tests/ut/python/dataset/test_minddataset_exception.py View File

@@ -228,3 +228,24 @@ def test_minddataset_shard_id_bigger_than_num_shard():

os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))

def test_cv_minddataset_partition_num_samples_equals_0():
    """num_samples=0 must be rejected with a clear error for any shard config."""
    create_cv_mindrecord(1)
    columns_list = ["data", "label"]
    num_readers = 4

    def partitions(num_shards):
        # Constructing the dataset with num_samples=0 is expected to raise;
        # num_iter is only counted so the iterator is actually driven if
        # construction unexpectedly succeeds.
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
                                      num_shards=num_shards,
                                      shard_id=partition_id, num_samples=0)
            num_iter = 0
            for _ in data_set.create_dict_iterator():
                num_iter += 1

    try:
        with pytest.raises(Exception) as error_info:
            partitions(5)
        # Use error_info.value: since pytest 5.0, str(error_info) no longer
        # contains the exception message, only its location.
        assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value)
    finally:
        # Always remove the generated mindrecord files, even when the
        # assertion fails, so later tests do not trip over stale fixtures.
        os.remove(CV_FILE_NAME)
        os.remove("{}.db".format(CV_FILE_NAME))

Loading…
Cancel
Save