Merge pull request !603 from liyong126/update_pk_sampler
@@ -435,12 +435,12 @@ void bindSamplerOps(py::module *m) {
     .def(py::init<std::vector<int64_t>, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed());
 
   (void)py::class_<mindrecord::ShardPkSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardPkSample>>(
     *m, "MindrecordPkSampler")
-    .def(py::init([](int64_t kVal, bool shuffle) {
+    .def(py::init([](int64_t kVal, std::string kColumn, bool shuffle) {
       if (shuffle == true) {
-        return std::make_shared<mindrecord::ShardPkSample>("label", kVal, std::numeric_limits<int64_t>::max(),
+        return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal, std::numeric_limits<int64_t>::max(),
                                                            GetSeed());
       } else {
-        return std::make_shared<mindrecord::ShardPkSample>("label", kVal);
+        return std::make_shared<mindrecord::ShardPkSample>(kColumn, kVal);
       }
     }));
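With this change the binding's init lambda takes the category column name between the sample count and the shuffle flag, so the Python layer must pass three arguments. A minimal sketch of the expected call shape, mirroring the PKSampler change later in this diff (the variable names are placeholders):

    # Expected call shape of the updated binding; num_val, class_column and
    # shuffle stand in for the values held by the Python PKSampler below.
    sampler_handle = cde.MindrecordPkSampler(num_val, class_column, shuffle)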
@@ -316,6 +316,10 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql,
 }
 
 MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set<std::string> &categories) {
+  if (column_schema_id_.find(category_field) == column_schema_id_.end()) {
+    MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
+    return FAILED;
+  }
   auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field));
   if (SUCCESS != ret.first) {
     return FAILED;
@@ -719,6 +723,11 @@ int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::stri
   for (auto &field : index_fields) {
     map_schema_id_fields[field.second] = field.first;
   }
+  if (map_schema_id_fields.find(category_field) == map_schema_id_fields.end()) {
+    MS_LOG(ERROR) << "Field " << category_field << " does not exist.";
+    return -1;
+  }
   auto ret =
     ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field));
   if (SUCCESS != ret.first) {
@@ -38,7 +38,7 @@ MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; }
 
 int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) {
   if (dataset_size == 0) return dataset_size;
-  if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) {
+  if (dataset_size > 0 && num_classes > 0 && num_categories_ > 0 && num_elements_ > 0) {
     return std::min(num_categories_, num_classes) * num_elements_;
   }
   return -1;
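The extra num_classes > 0 condition keeps the formula min(num_categories_, num_classes) * num_elements_ from silently yielding 0 samples when the dataset reports no classes. A minimal Python sketch of the same rule, with hypothetical numbers, for illustration only:

    # Illustration of the rule in ShardCategory::GetNumSamples, not the actual implementation.
    def pk_num_samples(dataset_size, num_classes, num_categories, num_elements):
        if dataset_size == 0:
            return dataset_size
        if dataset_size > 0 and num_classes > 0 and num_categories > 0 and num_elements > 0:
            return min(num_categories, num_classes) * num_elements
        return -1  # sample count cannot be determined

    # Hypothetical numbers: num_val=5 per class, num_class=3 requested, 10 classes
    # in the dataset -> min(3, 10) * 5 == 15 samples; num_classes == 0 now falls
    # through to -1 instead of returning 0.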
@@ -152,6 +152,7 @@ class PKSampler(BuiltinSampler):
         num_val (int): Number of elements to sample for each class.
         num_class (int, optional): Number of classes to sample (default=None, all classes).
         shuffle (bool, optional): If true, the class IDs are shuffled (default=False).
+        class_column (str, optional): Name of the column used to classify the dataset (default='label'), used only by MindDataset.
 
     Examples:
         >>> import mindspore.dataset as ds
@@ -168,7 +169,7 @@ class PKSampler(BuiltinSampler):
         ValueError: If shuffle is not boolean.
     """
 
-    def __init__(self, num_val, num_class=None, shuffle=False):
+    def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'):
         if num_val <= 0:
             raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val))
@@ -180,12 +181,16 @@ class PKSampler(BuiltinSampler):
         self.num_val = num_val
         self.shuffle = shuffle
+        self.class_column = class_column  # used only by MindDataset
 
     def create(self):
         return cde.PKSampler(self.num_val, self.shuffle)
 
     def _create_for_minddataset(self):
-        return cde.MindrecordPkSampler(self.num_val, self.shuffle)
+        if not self.class_column or not isinstance(self.class_column, str):
+            raise ValueError("class_column should be a non-empty string value, "
+                             "but got class_column={}".format(self.class_column))
+        return cde.MindrecordPkSampler(self.num_val, self.class_column, self.shuffle)
 
 
 class RandomSampler(BuiltinSampler):
     """
@@ -82,3 +82,18 @@ def test_minddataset_lack_db():
         num_iter += 1
     assert num_iter == 0
     os.remove(CV_FILE_NAME)
+
+
+def test_cv_minddataset_pk_sample_error_class_column():
+    create_cv_mindrecord(1)
+    columns_list = ["data", "file_name", "label"]
+    num_readers = 4
+    sampler = ds.PKSampler(5, None, True, 'no_exsit_column')
+    with pytest.raises(Exception, match="MindRecordOp launch failed"):
+        data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, sampler=sampler)
+        num_iter = 0
+        for item in data_set.create_dict_iterator():
+            num_iter += 1
+    os.remove(CV_FILE_NAME)
+    os.remove("{}.db".format(CV_FILE_NAME))