Browse Source

!969 [MD] fix bug in pk sampler of minddataset

Merge pull request !969 from liyong126/fix_mindrecord_pk_sampler
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
16930c562d
4 changed files with 37 additions and 5 deletions
  1. +7
    -3
      mindspore/ccsrc/mindrecord/io/shard_reader.cc
  2. +2
    -2
      mindspore/dataset/engine/datasets.py
  3. +14
    -0
      tests/ut/python/dataset/test_minddataset_exception.py
  4. +14
    -0
      tests/ut/python/dataset/test_minddataset_sampler.py

+ 7
- 3
mindspore/ccsrc/mindrecord/io/shard_reader.cc View File

@@ -316,11 +316,15 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
}

MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
std::map<std::string, uint64_t> index_columns;
for (auto &field : get_shard_header()->get_fields()) {
index_columns[field.second] = field.first;
}
if (index_columns.find(category_field) == index_columns.end()) {
MS_LOG(ERROR) << "Index field " << category_field << " does not exist.";
return FAILED;
}
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field));
if (SUCCESS != ret.first) {
return FAILED;
}


+ 2
- 2
mindspore/dataset/engine/datasets.py View File

@@ -2224,8 +2224,8 @@ class MindDataset(SourceDataset):
if block_reader is True and sampler is not None:
raise ValueError("block reader not allowed true when use sampler")

if shuffle is True and sampler is not None:
raise ValueError("shuffle not allowed true when use sampler")
if shuffle is not None and sampler is not None:
raise ValueError("shuffle not allowed when use sampler")

if block_reader is False and sampler is None:
self.global_shuffle = not bool(shuffle is False)


+ 14
- 0
tests/ut/python/dataset/test_minddataset_exception.py View File

@@ -97,3 +97,17 @@ def test_cv_minddataset_pk_sample_error_class_column():
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))

def test_cv_minddataset_pk_sample_exclusive_shuffle():
create_cv_mindrecord(1)
columns_list = ["data", "file_name", "label"]
num_readers = 4
sampler = ds.PKSampler(2)
with pytest.raises(Exception, match="shuffle not allowed when use sampler"):
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers,
sampler=sampler, shuffle=False)
num_iter = 0
for item in data_set.create_dict_iterator():
num_iter += 1
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))


+ 14
- 0
tests/ut/python/dataset/test_minddataset_sampler.py View File

@@ -60,7 +60,21 @@ def add_and_remove_cv_file():
os.remove("{}".format(x))
os.remove("{}.db".format(x))

def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
num_readers = 4
sampler = ds.PKSampler(2)
data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers,
sampler=sampler)

assert data_set.get_dataset_size() == 6
num_iter = 0
for item in data_set.create_dict_iterator():
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
logger.info("-------------- item[file_name]: \
{}------------------------".format("".join([chr(x) for x in item["file_name"]])))
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
num_iter += 1
def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
"""tutorial for cv minderdataset."""
columns_list = ["data", "file_name", "label"]


Loading…
Cancel
Save