!6300 [MD] fix get_dataset_size bug when set num_samples in DistributedSampler

Merge pull request !6300 from liyong126/fix_md_num_samples
tags/v1.0.0
mindspore-ci-bot 5 years ago
commit bef2118253
2 changed files with 32 additions and 12 deletions:
  1. mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc (+4, -2)
  2. tests/ut/python/dataset/test_minddataset.py (+28, -10)
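
What changed: when a MindDataset is sharded with num_shards/shard_id (a DistributedSampler under the hood) and num_samples is also given, get_dataset_size() previously ignored num_samples and reported the full per-shard size. The fix caps the reported size at num_samples. A rough usage sketch of the intended behavior (not part of the patch; the file path and the 10-record count are assumptions based on the updated tests):

    import mindspore.dataset as ds

    # Hypothetical MindRecord file with 10 records, read as shard 0 of 4.
    data_set = ds.MindDataset("test.mindrecord0", ["data", "file_name", "label"],
                              num_shards=4, shard_id=0, num_samples=3)

    # Per-shard size is ceil(10 / 4) = 3, then capped at num_samples=3.
    # With this fix, get_dataset_size() agrees with the rows actually iterated.
    assert data_set.get_dataset_size() == 3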

mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc (+4, -2)

@@ -37,11 +37,13 @@ ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, boo
 
 int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (no_of_padded_samples_ <= 0) {
+    int64_t res = 0;
     if (dataset_size % denominator_ == 0) {
-      return dataset_size / denominator_ * numerator_;
+      res = dataset_size / denominator_ * numerator_;
     } else {
-      return dataset_size / denominator_ * numerator_ + 1;
+      res = dataset_size / denominator_ * numerator_ + 1;
     }
+    return no_of_samples_ == 0 ? res : std::min(static_cast<int64_t>(no_of_samples_), res);
   } else {
     auto padded_size = dataset_size + no_of_padded_samples_;
     if (padded_size % denominator_ == 0) {
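
In plain terms, the non-padded branch now computes the per-shard count and then caps it at the user-requested num_samples (no_of_samples_) instead of returning the raw count. A minimal Python sketch of that arithmetic, assuming numerator_/denominator_ reduce to 1/num_shards for a DistributedSampler and ignoring the padded-sample branch:

    def get_num_samples(dataset_size, num_shards, num_samples):
        # Ceiling division: each shard reports ceil(dataset_size / num_shards) rows.
        if dataset_size % num_shards == 0:
            res = dataset_size // num_shards
        else:
            res = dataset_size // num_shards + 1
        # New behavior: a non-zero num_samples caps the reported size.
        return res if num_samples == 0 else min(num_samples, res)

    # e.g. with 10 records: get_num_samples(10, 4, 3) == 3, get_num_samples(10, 5, 3) == 2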


tests/ut/python/dataset/test_minddataset.py (+28, -10)

@@ -278,6 +278,8 @@ def test_cv_minddataset_partition_num_samples_0(add_and_remove_cv_file):
             data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
                                       shard_id=partition_id, num_samples=1)
+
+            assert data_set.get_dataset_size() == 1
             num_iter = 0
             for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                 logger.info("-------------- partition : {} ------------------------".format(partition_id))
@@ -301,6 +303,8 @@ def test_cv_minddataset_partition_num_samples_1(add_and_remove_cv_file):
             data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
                                       shard_id=partition_id, num_samples=2)
+
+            assert data_set.get_dataset_size() == 2
             num_iter = 0
             for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                 logger.info("-------------- partition : {} ------------------------".format(partition_id))
@@ -319,11 +323,13 @@ def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
     columns_list = ["data", "file_name", "label"]
     num_readers = 4
 
-    def partitions(num_shards):
+    def partitions(num_shards, expect):
         for partition_id in range(num_shards):
             data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
                                       shard_id=partition_id, num_samples=3)
+
+            assert data_set.get_dataset_size() == expect
             num_iter = 0
             for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
                 logger.info("-------------- partition : {} ------------------------".format(partition_id))
@@ -332,10 +338,25 @@ def test_cv_minddataset_partition_num_samples_2(add_and_remove_cv_file):
                 num_iter += 1
         return num_iter
 
-    assert partitions(4) == 3
-    assert partitions(5) == 2
-    assert partitions(9) == 2
+    assert partitions(4, 3) == 3
+    assert partitions(5, 2) == 2
+    assert partitions(9, 2) == 2
+
+
+def test_cv_minddataset_partition_num_samples_3(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+
+    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, num_shards=1, shard_id=0, num_samples=5)
+
+    assert data_set.get_dataset_size() == 5
+    num_iter = 0
+    for item in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
+        logger.info("-------------- item[file_name]: {}-----------------------".format(item["file_name"]))
+        logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+        num_iter += 1
+
+    assert num_iter == 5
 
 def test_cv_minddataset_partition_tutorial_check_shuffle_result(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
@@ -841,13 +862,10 @@ def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_
 
     # define map operations
     decode_op = vision.Decode()
-    resize_op = vision.Resize(
-        (resize_height, resize_width), ds.transforms.vision.Inter.LINEAR)
+    resize_op = vision.Resize((resize_height, resize_width))
 
-    data_set = data_set.map(
-        input_columns=["data"], operations=decode_op, num_parallel_workers=4)
-    data_set = data_set.map(
-        input_columns=["data"], operations=resize_op, num_parallel_workers=4)
+    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4)
+    data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4)
 
     data_set = data_set.batch(2)
     assert data_set.get_dataset_size() == 5

