Merge pull request !969 from liyong126/fix_mindrecord_pk_sampler (tags/v0.3.0-alpha)
| @@ -316,11 +316,15 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, | |||||
| } | } | ||||
| MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { | MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) { | ||||
| if (column_schema_id_.find(category_field) == column_schema_id_.end()) { | |||||
| MS_LOG(ERROR) << "Field " << category_field << " does not exist."; | |||||
| std::map<std::string, uint64_t> index_columns; | |||||
| for (auto &field : get_shard_header()->get_fields()) { | |||||
| index_columns[field.second] = field.first; | |||||
| } | |||||
| if (index_columns.find(category_field) == index_columns.end()) { | |||||
| MS_LOG(ERROR) << "Index field " << category_field << " does not exist."; | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); | |||||
| auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(index_columns[category_field], category_field)); | |||||
| if (SUCCESS != ret.first) { | if (SUCCESS != ret.first) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -2224,8 +2224,8 @@ class MindDataset(SourceDataset): | |||||
| if block_reader is True and sampler is not None: | if block_reader is True and sampler is not None: | ||||
| raise ValueError("block reader not allowed true when use sampler") | raise ValueError("block reader not allowed true when use sampler") | ||||
| if shuffle is True and sampler is not None: | |||||
| raise ValueError("shuffle not allowed true when use sampler") | |||||
| if shuffle is not None and sampler is not None: | |||||
| raise ValueError("shuffle not allowed when use sampler") | |||||
| if block_reader is False and sampler is None: | if block_reader is False and sampler is None: | ||||
| self.global_shuffle = not bool(shuffle is False) | self.global_shuffle = not bool(shuffle is False) | ||||
| @@ -97,3 +97,17 @@ def test_cv_minddataset_pk_sample_error_class_column(): | |||||
| os.remove(CV_FILE_NAME) | os.remove(CV_FILE_NAME) | ||||
| os.remove("{}.db".format(CV_FILE_NAME)) | os.remove("{}.db".format(CV_FILE_NAME)) | ||||
| def test_cv_minddataset_pk_sample_exclusive_shuffle(): | |||||
| create_cv_mindrecord(1) | |||||
| columns_list = ["data", "file_name", "label"] | |||||
| num_readers = 4 | |||||
| sampler = ds.PKSampler(2) | |||||
| with pytest.raises(Exception, match="shuffle not allowed when use sampler"): | |||||
| data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, | |||||
| sampler=sampler, shuffle=False) | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| num_iter += 1 | |||||
| os.remove(CV_FILE_NAME) | |||||
| os.remove("{}.db".format(CV_FILE_NAME)) | |||||
| @@ -60,7 +60,21 @@ def add_and_remove_cv_file(): | |||||
| os.remove("{}".format(x)) | os.remove("{}".format(x)) | ||||
| os.remove("{}.db".format(x)) | os.remove("{}.db".format(x)) | ||||
| def test_cv_minddataset_pk_sample_no_column(add_and_remove_cv_file): | |||||
| """tutorial for cv minderdataset.""" | |||||
| num_readers = 4 | |||||
| sampler = ds.PKSampler(2) | |||||
| data_set = ds.MindDataset(CV_FILE_NAME + "0", None, num_readers, | |||||
| sampler=sampler) | |||||
| assert data_set.get_dataset_size() == 6 | |||||
| num_iter = 0 | |||||
| for item in data_set.create_dict_iterator(): | |||||
| logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) | |||||
| logger.info("-------------- item[file_name]: \ | |||||
| {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) | |||||
| logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) | |||||
| num_iter += 1 | |||||
| def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): | def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): | ||||
| """tutorial for cv minderdataset.""" | """tutorial for cv minderdataset.""" | ||||
| columns_list = ["data", "file_name", "label"] | columns_list = ["data", "file_name", "label"] | ||||