Merge pull request !1808 from Jamie/numsamplestags/v0.5.0-beta
| @@ -856,9 +856,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_parallel_workers") { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| @@ -893,9 +891,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_parallel_workers") { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| @@ -930,9 +926,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> * | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_parallel_workers") { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| @@ -966,9 +960,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_parallel_workers") { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| @@ -1001,9 +993,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_parallel_workers") { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| @@ -1039,10 +1029,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas | |||
| (void)builder.SetNumWorkers(ToInt(value)); | |||
| } else if (key == "schema_file_path" || key == "schema_json_string") { | |||
| schema_exists = true; | |||
| } else if (key == "num_samples") { | |||
| (void)builder.SetTotalRows(ToInt(value)); | |||
| } else if (key == "columns_list") { | |||
| columns_to_load = ToStringVector(value); | |||
| } else if (key == "num_samples") { | |||
| // This is not sampling here. The random data op needs to know how much data to | |||
| // generate. It does not currently support sampling. | |||
| (void)builder.SetTotalRows(ToInt(value)); | |||
| } | |||
| } | |||
| if (schema_exists) { | |||
| @@ -1077,9 +1069,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp> | |||
| std::string key = py::str(arg.first); | |||
| py::handle value = arg.second; | |||
| if (!value.is_none()) { | |||
| if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "num_parallel_workers") { | |||
| if (key == "num_parallel_workers") { | |||
| (void)builder->SetNumWorkers(ToInt(value)); | |||
| } else if (key == "sampler") { | |||
| auto create = py::reinterpret_borrow<py::object>(value).attr("create"); | |||
| @@ -1121,8 +1111,6 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp | |||
| (void)builder->SetDecode(ToBool(value)); | |||
| } else if (key == "extensions") { | |||
| (void)builder->SetExtensions(ToStringSet(value)); | |||
| } else if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| } else if (key == "dataset_type") { | |||
| (void)builder->SetDatasetType(ToString(value)); | |||
| } | |||
| @@ -1153,7 +1141,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset | |||
| } else if (key == "shuffle_files") { | |||
| (void)builder->SetShuffleFiles(ToBool(value)); | |||
| } else if (key == "num_samples") { | |||
| (void)builder->SetNumSamples(ToInt(value)); | |||
| (void)builder->SetTotalRows(ToInt(value)); | |||
| } else if (key == "num_shards") { | |||
| (void)builder->SetNumDevices(ToInt(value)); | |||
| } else if (key == "shard_id") { | |||
| @@ -49,7 +49,6 @@ | |||
| #include "dataset/engine/datasetops/source/sampler/pk_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/subset_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/python_sampler.h" | |||
| @@ -143,17 +142,16 @@ void bindDatasetOps(py::module *m) { | |||
| }); | |||
| (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp") | |||
| .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples, bool isCifar10) { | |||
| .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) { | |||
| int64_t count = 0; | |||
| THROW_IF_ERROR(CifarOp::CountTotalRows(dir, numSamples, isCifar10, &count)); | |||
| THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count)); | |||
| return count; | |||
| }); | |||
| (void)py::class_<ImageFolderOp, DatasetOp, std::shared_ptr<ImageFolderOp>>(*m, "ImageFolderOp") | |||
| .def_static("get_num_rows_and_classes", [](const std::string &path, int64_t numSamples) { | |||
| .def_static("get_num_rows_and_classes", [](const std::string &path) { | |||
| int64_t count = 0, num_classes = 0; | |||
| THROW_IF_ERROR( | |||
| ImageFolderOp::CountRowsAndClasses(path, numSamples, std::set<std::string>{}, &count, &num_classes)); | |||
| THROW_IF_ERROR(ImageFolderOp::CountRowsAndClasses(path, std::set<std::string>{}, &count, &num_classes)); | |||
| return py::make_tuple(count, num_classes); | |||
| }); | |||
| @@ -172,22 +170,21 @@ void bindDatasetOps(py::module *m) { | |||
| (void)py::class_<ManifestOp, DatasetOp, std::shared_ptr<ManifestOp>>(*m, "ManifestOp") | |||
| .def_static("get_num_rows_and_classes", | |||
| [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) { | |||
| [](const std::string &file, const py::dict &dict, const std::string &usage) { | |||
| int64_t count = 0, num_classes = 0; | |||
| THROW_IF_ERROR(ManifestOp::CountTotalRows(file, numSamples, dict, usage, &count, &num_classes)); | |||
| THROW_IF_ERROR(ManifestOp::CountTotalRows(file, dict, usage, &count, &num_classes)); | |||
| return py::make_tuple(count, num_classes); | |||
| }) | |||
| .def_static("get_class_indexing", | |||
| [](const std::string &file, int64_t numSamples, const py::dict &dict, const std::string &usage) { | |||
| std::map<std::string, int32_t> output_class_indexing; | |||
| THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, numSamples, dict, usage, &output_class_indexing)); | |||
| return output_class_indexing; | |||
| }); | |||
| .def_static("get_class_indexing", [](const std::string &file, const py::dict &dict, const std::string &usage) { | |||
| std::map<std::string, int32_t> output_class_indexing; | |||
| THROW_IF_ERROR(ManifestOp::GetClassIndexing(file, dict, usage, &output_class_indexing)); | |||
| return output_class_indexing; | |||
| }); | |||
| (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp") | |||
| .def_static("get_num_rows", [](const std::string &dir, int64_t numSamples) { | |||
| .def_static("get_num_rows", [](const std::string &dir) { | |||
| int64_t count = 0; | |||
| THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count)); | |||
| THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count)); | |||
| return count; | |||
| }); | |||
| @@ -206,13 +203,13 @@ void bindDatasetOps(py::module *m) { | |||
| [](const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples) { | |||
| int64_t count = 0; | |||
| THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, numSamples, &count)); | |||
| THROW_IF_ERROR(VOCOp::CountTotalRows(dir, task_type, task_mode, dict, &count)); | |||
| return count; | |||
| }) | |||
| .def_static("get_class_indexing", [](const std::string &dir, const std::string &task_type, | |||
| const std::string &task_mode, const py::dict &dict, int64_t numSamples) { | |||
| const std::string &task_mode, const py::dict &dict) { | |||
| std::map<std::string, int32_t> output_class_indexing; | |||
| THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, numSamples, &output_class_indexing)); | |||
| THROW_IF_ERROR(VOCOp::GetClassIndexing(dir, task_type, task_mode, dict, &output_class_indexing)); | |||
| return output_class_indexing; | |||
| }); | |||
| } | |||
| @@ -452,25 +449,19 @@ void bindSamplerOps(py::module *m) { | |||
| (void)py::class_<mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardOperator>>(*m, "ShardOperator"); | |||
| (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(*m, "DistributedSampler") | |||
| .def(py::init<int64_t, int64_t, bool, uint32_t>(), py::arg("numDev"), py::arg("devId"), py::arg("shuffle"), | |||
| py::arg("seed")); | |||
| .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t>()); | |||
| (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler") | |||
| .def(py::init<int64_t, bool>(), py::arg("kVal"), py::arg("shuffle")); | |||
| .def(py::init<int64_t, int64_t, bool>()); | |||
| (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler") | |||
| .def(py::init<bool, bool, int64_t>(), py::arg("replacement"), py::arg("reshuffle_each_epoch"), | |||
| py::arg("num_samples")) | |||
| .def(py::init<bool, bool>(), py::arg("replacement"), py::arg("reshuffle_each_epoch")); | |||
| .def(py::init<int64_t, bool, bool>()); | |||
| (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m, "SequentialSampler") | |||
| .def(py::init<>()); | |||
| (void)py::class_<SubsetSampler, Sampler, std::shared_ptr<SubsetSampler>>(*m, "SubsetSampler") | |||
| .def(py::init<int64_t, int64_t>(), py::arg("start_index"), py::arg("subset_size")); | |||
| .def(py::init<int64_t, int64_t>()); | |||
| (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(*m, "SubsetRandomSampler") | |||
| .def(py::init<std::vector<int64_t>>(), py::arg("indices")); | |||
| .def(py::init<int64_t, std::vector<int64_t>>()); | |||
| (void)py::class_<mindrecord::ShardSample, mindrecord::ShardOperator, std::shared_ptr<mindrecord::ShardSample>>( | |||
| *m, "MindrecordSubsetRandomSampler") | |||
| @@ -487,11 +478,10 @@ void bindSamplerOps(py::module *m) { | |||
| })); | |||
| (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(*m, "WeightedRandomSampler") | |||
| .def(py::init<std::vector<double>, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), | |||
| py::arg("replacement")); | |||
| .def(py::init<int64_t, std::vector<double>, bool>()); | |||
| (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler") | |||
| .def(py::init<py::object>(), py::arg("pySampler")); | |||
| .def(py::init<int64_t, py::object>()); | |||
| } | |||
| void bindInfoObjects(py::module *m) { | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_num_samples_(0) { | |||
| CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| @@ -38,7 +38,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) { | |||
| MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << "."; | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| if (builder_sampler_ == nullptr) { | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| @@ -47,10 +49,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) { | |||
| // label is like this:0 1 0 0 1...... | |||
| RETURN_IF_NOT_OK( | |||
| builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| *op = | |||
| std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_, | |||
| builder_decode_, builder_dataset_type_, builder_extensions_, std::move(builder_schema_), | |||
| std::move(builder_sampler_), builder_num_samples_); | |||
| *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, | |||
| builder_op_connector_size_, builder_decode_, builder_dataset_type_, | |||
| builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_)); | |||
| if (*op == nullptr) { | |||
| return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null"); | |||
| } | |||
| @@ -68,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() { | |||
| CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, | |||
| bool decode, const std::string &dataset_type, const std::set<std::string> &exts, | |||
| std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler, int64_t num_samples) | |||
| std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler) | |||
| : ParallelOp(num_workers, queue_size), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| folder_path_(dir), | |||
| @@ -77,8 +78,6 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri | |||
| data_schema_(std::move(schema)), | |||
| sampler_(std::move(sampler)), | |||
| num_rows_in_attr_file_(0), | |||
| num_rows_exact_(0), | |||
| num_samples_(num_samples), | |||
| dataset_type_(dataset_type) { | |||
| // Set the column name map (base class field) | |||
| for (int32_t index = 0; index < data_schema_->NumColumns(); index++) { | |||
| @@ -202,13 +201,6 @@ Status CelebAOp::ParseImageAttrInfo() { | |||
| RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos)); | |||
| while (!image_infos.empty() && needMoreData) { | |||
| for (uint32_t index = 0; index < image_infos.size(); index++) { | |||
| if (num_samples_ != 0 && image_labels_vec_.size() >= num_samples_) { | |||
| MS_LOG(WARNING) << "Image number(" << image_labels_vec_.size() << " is more than" | |||
| << " rows num eval attr file(" << num_rows_in_attr_file_ << ") or num samples(" << num_samples_ | |||
| << ")."; | |||
| needMoreData = false; | |||
| break; | |||
| } | |||
| std::string image_info = image_infos[index]; | |||
| std::vector<std::string> split = Split(image_info); | |||
| std::pair<std::string, std::vector<int32_t>> image_labels; | |||
| @@ -239,14 +231,13 @@ Status CelebAOp::ParseImageAttrInfo() { | |||
| RETURN_IF_NOT_OK(attr_info_queue_->PopFront(&image_infos)); | |||
| } | |||
| num_rows_exact_ = image_labels_vec_.size(); | |||
| num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_exact_) ? num_rows_exact_ : num_samples_; | |||
| if (num_rows_exact_ == 0) { | |||
| num_rows_ = image_labels_vec_.size(); | |||
| if (num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_exact_ << "."; | |||
| MS_LOG(DEBUG) << "Celeba dataset rows number is " << num_rows_ << "."; | |||
| return Status::OK(); | |||
| } | |||
| @@ -268,28 +259,6 @@ std::vector<std::string> CelebAOp::Split(const std::string &line) { | |||
| return split; | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status CelebAOp::GetNumSamples(int64_t *num) const { | |||
| if (num == nullptr || num_samples_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| Status CelebAOp::GetNumRowsInDataset(int64_t *num) const { | |||
| if (num == nullptr || num_rows_exact_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API CelebADataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| *num = num_rows_exact_; | |||
| return Status::OK(); | |||
| } | |||
| // Main logic, Register Queue with TaskGroup, launch all threads and do the functor's work | |||
| Status CelebAOp::operator()() { | |||
| RETURN_IF_NOT_OK(LaunchThreadsAndInitOp()); | |||
| @@ -310,9 +279,8 @@ Status CelebAOp::AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer) { | |||
| RETURN_IF_NOT_OK((*data_buffer)->PopRow(&sample_row)); | |||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) { | |||
| if ((*itr) >= num_rows_exact_) { | |||
| MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_exact_ | |||
| << "."; | |||
| if ((*itr) >= num_rows_) { | |||
| MS_LOG(WARNING) << "Sample Id (" << *itr << ") is out of bounds, skipping. Max id is " << num_rows_ << "."; | |||
| continue; | |||
| } | |||
| keys.push_back(*itr); | |||
| @@ -446,7 +414,7 @@ void CelebAOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call the super class for displaying any common detailed info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nNumber of rows:" << num_rows_exact_ << "\nceleba dir: " << folder_path_ << "\n\n"; | |||
| out << "\nNumber of rows:" << num_rows_ << "\nceleba dir: " << folder_path_ << "\n\n"; | |||
| } | |||
| } | |||
| @@ -108,14 +108,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param int64_t num_samples | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param const std::string dataset_type: type to be read | |||
| // @return Builder setter method returns reference to the builder. | |||
| @@ -141,7 +133,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| std::set<std::string> builder_extensions_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| int64_t builder_num_samples_; | |||
| std::string builder_dataset_type_; | |||
| }; | |||
| @@ -153,7 +144,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read | |||
| CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode, | |||
| const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema, | |||
| std::shared_ptr<Sampler> sampler, int64_t num_samples); | |||
| std::shared_ptr<Sampler> sampler); | |||
| ~CelebAOp() override = default; | |||
| @@ -163,16 +154,6 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get numRows | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumSamples(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get numRows | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumRowsInDataset(int64_t *num) const override; | |||
| // Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector | |||
| // @param int32_t worker_id - id of each worker | |||
| // @return Status - The error code return | |||
| @@ -233,11 +214,9 @@ class CelebAOp : public ParallelOp, RandomAccessOp { | |||
| std::shared_ptr<Sampler> sampler_; | |||
| std::unique_ptr<Queue<std::vector<std::string>>> attr_info_queue_; | |||
| int64_t num_rows_in_attr_file_; // rows number specified in attr file | |||
| int64_t num_rows_exact_; // exact rows number,maybe is less than rows_num_in_attr_file_ | |||
| QueueList<std::unique_ptr<IOBlock>> io_block_queues_; | |||
| WaitPost wp_; | |||
| std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_; | |||
| int64_t num_samples_; | |||
| std::string dataset_type_; | |||
| std::ifstream partition_file_; | |||
| }; | |||
| @@ -35,7 +35,7 @@ constexpr uint32_t kCifarImageChannel = 3; | |||
| constexpr uint32_t kCifarBlockImageNum = 5; | |||
| constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel; | |||
| CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) { | |||
| CifarOp::Builder::Builder() : sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| num_workers_ = cfg->num_parallel_workers(); | |||
| rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| @@ -46,7 +46,9 @@ CifarOp::Builder::Builder() : num_samples_(0), sampler_(nullptr) { | |||
| Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| if (sampler_ == nullptr) { | |||
| sampler_ = std::make_shared<SequentialSampler>(); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| } | |||
| schema_ = std::make_unique<DataSchema>(); | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| @@ -62,7 +64,7 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) { | |||
| ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar))); | |||
| } | |||
| *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, num_samples_, | |||
| *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_, | |||
| std::move(schema_), std::move(sampler_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -76,16 +78,13 @@ Status CifarOp::Builder::SanityCheck() { | |||
| } | |||
| CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, | |||
| int32_t queue_size, int64_t num_samples, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler) | |||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| : ParallelOp(num_works, queue_size), | |||
| cifar_type_(type), | |||
| rows_per_buffer_(rows_per_buf), | |||
| folder_path_(file_dir), | |||
| num_samples_(num_samples), | |||
| data_schema_(std::move(data_schema)), | |||
| sampler_(std::move(sampler)), | |||
| num_rows_(0), | |||
| row_cnt_(0), | |||
| buf_cnt_(0) { | |||
| // set the column name map (base class field) | |||
| @@ -112,8 +111,7 @@ Status CifarOp::operator()() { | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) { | |||
| keys.push_back(*itr); | |||
| row_cnt_++; | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| if (row_cnt_ >= num_samples_) break; // enough row read, break for loop | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||
| RETURN_IF_NOT_OK(io_block_queues_[buf_cnt_++ % num_workers_]->Add( | |||
| std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone)))); | |||
| @@ -255,30 +253,6 @@ Status CifarOp::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status CifarOp::GetNumSamples(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset"; | |||
| std::string err_msg = "There is no valid data matching the dataset API " + api + | |||
| ".Please check file path or dataset API validation first."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| (*num) = num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status CifarOp::GetNumRowsInDataset(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset"; | |||
| std::string err_msg = "There is no valid data matching the dataset API " + api + | |||
| ".Please check file path or dataset API validation first."; | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| (*num) = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| Status CifarOp::ReadCifarBlockDataAsync() { | |||
| TaskManager::FindMe()->Post(); | |||
| RETURN_IF_NOT_OK(GetCifarFiles()); | |||
| @@ -404,7 +378,6 @@ Status CifarOp::ParseCifarData() { | |||
| } | |||
| cifar_image_label_pairs_.shrink_to_fit(); | |||
| num_rows_ = cifar_image_label_pairs_.size(); | |||
| num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; | |||
| if (num_rows_ == 0) { | |||
| std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset"; | |||
| std::string err_msg = "There is no valid data matching the dataset API " + api + | |||
| @@ -432,11 +405,11 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co | |||
| return Status::OK(); | |||
| } | |||
| Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count) { | |||
| Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) { | |||
| // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block() | |||
| std::shared_ptr<CifarOp> op; | |||
| *count = 0; | |||
| RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetNumSamples(numSamples).SetCifarType(isCIFAR10).Build(&op)); | |||
| RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op)); | |||
| RETURN_IF_NOT_OK(op->GetCifarFiles()); | |||
| if (op->cifar_type_ == kCifar10) { | |||
| constexpr int64_t num_cifar10_records = 10000; | |||
| @@ -448,7 +421,6 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool | |||
| } | |||
| *count = *count + num_cifar10_records; | |||
| } | |||
| *count = *count < numSamples || numSamples == 0 ? *count : numSamples; | |||
| return Status::OK(); | |||
| } else { | |||
| int64_t num_cifar100_records = 0; | |||
| @@ -470,7 +442,7 @@ Status CifarOp::CountTotalRows(const std::string &dir, int64_t numSamples, bool | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| } | |||
| *count = num_cifar100_records < numSamples || numSamples == 0 ? num_cifar100_records : numSamples; | |||
| *count = num_cifar100_records; | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| @@ -73,14 +73,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param uint64_t num_samples | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetNumSamples(uint64_t num_samples) { | |||
| num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| @@ -121,7 +113,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| std::string dir_; | |||
| int32_t num_workers_; | |||
| uint64_t num_samples_; | |||
| int32_t rows_per_buffer_; | |||
| int32_t op_connect_size_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| @@ -137,7 +128,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| // @param uint32_t - queueSize - connector queue size | |||
| // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | |||
| CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size, | |||
| int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| // Destructor. | |||
| ~CifarOp() = default; | |||
| @@ -152,16 +143,6 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get numRows | |||
| // @param uint64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumSamples(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset | |||
| // @param uint64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumRowsInDataset(int64_t *num) const override; | |||
| // A print method typically used for debugging | |||
| // @param out | |||
| // @param show_all | |||
| @@ -169,11 +150,10 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| // Function to count the number of samples in the CIFAR dataset | |||
| // @param dir path to the CIFAR directory | |||
| // @param numSamples maximum number of samples requested | |||
| // @param isCIFAR10 true if CIFAR10 and false if CIFAR100 | |||
| // @param count output arg that will hold the minimum of the actual dataset size and numSamples | |||
| // @param count output arg that will hold the actual dataset size | |||
| // @return | |||
| static Status CountTotalRows(const std::string &dir, int64_t numSamples, bool isCIFAR10, int64_t *count); | |||
| static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count); | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| @@ -227,10 +207,8 @@ class CifarOp : public ParallelOp, public RandomAccessOp { | |||
| CifarType cifar_type_; | |||
| int32_t rows_per_buffer_; | |||
| std::string folder_path_; | |||
| int64_t num_samples_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| int64_t num_rows_; | |||
| int64_t row_cnt_; | |||
| int64_t buf_cnt_; | |||
| @@ -26,8 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| ImageFolderOp::Builder::Builder() | |||
| : builder_decode_(false), builder_recursive_(false), builder_num_samples_(0), builder_sampler_(nullptr) { | |||
| ImageFolderOp::Builder::Builder() : builder_decode_(false), builder_recursive_(false), builder_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| @@ -37,7 +36,9 @@ ImageFolderOp::Builder::Builder() | |||
| Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| if (builder_sampler_ == nullptr) { | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(); | |||
| int64_t num_samples = 0; // default num samples of 0 means to sample entire set of data | |||
| int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| @@ -46,9 +47,9 @@ Status ImageFolderOp::Builder::Build(std::shared_ptr<ImageFolderOp> *ptr) { | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| *ptr = std::make_shared<ImageFolderOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, | |||
| builder_op_connector_size_, builder_num_samples_, builder_recursive_, | |||
| builder_decode_, builder_extensions_, builder_labels_to_read_, | |||
| std::move(builder_schema_), std::move(builder_sampler_)); | |||
| builder_op_connector_size_, builder_recursive_, builder_decode_, | |||
| builder_extensions_, builder_labels_to_read_, std::move(builder_schema_), | |||
| std::move(builder_sampler_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -61,20 +62,18 @@ Status ImageFolderOp::Builder::SanityCheck() { | |||
| } | |||
| ImageFolderOp::ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, | |||
| int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts, | |||
| bool recursive, bool do_decode, const std::set<std::string> &exts, | |||
| const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler) | |||
| : ParallelOp(num_wkrs, queue_size), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| folder_path_(file_dir), | |||
| num_samples_(num_samples), | |||
| recursive_(recursive), | |||
| decode_(do_decode), | |||
| extensions_(exts), | |||
| class_index_(map), | |||
| data_schema_(std::move(data_schema)), | |||
| sampler_(std::move(sampler)), | |||
| num_rows_(0), | |||
| row_cnt_(0), | |||
| buf_cnt_(0), | |||
| sampler_ind_(0), | |||
| @@ -117,7 +116,6 @@ Status ImageFolderOp::PrescanMasterEntry(const std::string &filedir) { | |||
| } | |||
| image_label_pairs_.shrink_to_fit(); | |||
| num_rows_ = image_label_pairs_.size(); | |||
| num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; | |||
| // free memory of two queues used for pre-scan | |||
| folder_name_queue_->Reset(); | |||
| image_name_queue_->Reset(); | |||
| @@ -138,8 +136,7 @@ Status ImageFolderOp::operator()() { | |||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | |||
| if (sample_ids->type() != DataType(DataType::DE_INT64)) RETURN_STATUS_UNEXPECTED("Sampler Tensor isn't int64"); | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) { | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| if (row_cnt_ >= num_samples_) break; // enough row read, break for loop | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| keys.push_back(*itr); | |||
| row_cnt_++; | |||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||
| @@ -272,28 +269,6 @@ Status ImageFolderOp::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status ImageFolderOp::GetNumSamples(int64_t *num) const { | |||
| if (num == nullptr || num_samples_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status ImageFolderOp::GetNumRowsInDataset(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API ImageFolderDatasetV2.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status ImageFolderOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const { | |||
| if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { | |||
| @@ -413,16 +388,14 @@ Status ImageFolderOp::LaunchThreadsAndInitOp() { | |||
| return Status::OK(); | |||
| } | |||
| Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t &num_samples, | |||
| const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes, | |||
| int64_t dev_id, int64_t num_dev) { | |||
| Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows, | |||
| int64_t *num_classes, int64_t dev_id, int64_t num_dev) { | |||
| Path dir(path); | |||
| std::string err_msg = ""; | |||
| int64_t row_cnt = 0; | |||
| err_msg += (dir.Exists() == false || dir.IsDirectory() == false) ? "unable to open dir " + path : ""; | |||
| err_msg += (num_classes == nullptr || num_rows == nullptr) ? "num_class/num_rows is null\n" : ""; | |||
| err_msg += (dev_id >= num_dev || num_dev <= 0) ? "invalid sharding config\n" : ""; | |||
| err_msg += num_samples < 0 ? "num_samples can't be negative! set it to 0 to use all samples\n" : ""; | |||
| if (err_msg.empty() == false) { | |||
| RETURN_STATUS_UNEXPECTED(err_msg); | |||
| } | |||
| @@ -441,10 +414,6 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t | |||
| while (dir_itr->hasNext()) { | |||
| if (exts.empty() || exts.find(subdir.Extension()) != exts.end()) { | |||
| ++row_cnt; | |||
| if (row_cnt == num_samples * num_dev) { | |||
| (*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1); | |||
| return Status::OK(); | |||
| } | |||
| } | |||
| } | |||
| foldernames.pop(); | |||
| @@ -107,14 +107,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param int64_t num_samples | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| @@ -153,7 +145,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| bool builder_recursive_; | |||
| std::string builder_dir_; | |||
| int32_t builder_num_workers_; | |||
| int64_t builder_num_samples_; | |||
| int32_t builder_rows_per_buffer_; | |||
| int32_t builder_op_connector_size_; | |||
| std::set<std::string> builder_extensions_; | |||
| @@ -169,10 +160,9 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| // @param int32_t queue_size - connector queue size | |||
| // @param std::set<std::string> exts - set of file extensions to read, if empty, read everything under the dir | |||
| // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | |||
| ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, | |||
| int64_t num_samples, bool recursive, bool do_decode, const std::set<std::string> &exts, | |||
| const std::map<std::string, int32_t> &map, std::unique_ptr<DataSchema>, | |||
| std::shared_ptr<Sampler> sampler); | |||
| ImageFolderOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool recursive, | |||
| bool do_decode, const std::set<std::string> &exts, const std::map<std::string, int32_t> &map, | |||
| std::unique_ptr<DataSchema>, std::shared_ptr<Sampler> sampler); | |||
| // Destructor. | |||
| ~ImageFolderOp() = default; | |||
| @@ -198,16 +188,6 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get numRows | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumSamples(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumRowsInDataset(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | |||
| // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class | |||
| // @return Status - The error code return | |||
| @@ -221,9 +201,8 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| // This function is a hack! It is to return the num_class and num_rows the old storageOp does. The result | |||
| // returned by this function may not be consistent with what image_folder_op is going to return | |||
| // user this at your own risk! | |||
| static Status CountRowsAndClasses(const std::string &path, const int64_t &num_samples, | |||
| const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes, | |||
| int64_t dev_id = 0, int64_t num_dev = 1); | |||
| static Status CountRowsAndClasses(const std::string &path, const std::set<std::string> &exts, int64_t *num_rows, | |||
| int64_t *num_classes, int64_t dev_id = 0, int64_t num_dev = 1); | |||
| // Base-class override for NodePass visitor acceptor. | |||
| // @param p - Pointer to the NodePass to be accepted. | |||
| @@ -266,14 +245,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t rows_per_buffer_; | |||
| std::string folder_path_; // directory of image folder | |||
| int64_t num_samples_; | |||
| bool recursive_; | |||
| bool decode_; | |||
| std::set<std::string> extensions_; // extensions allowed | |||
| std::map<std::string, int32_t> class_index_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| int64_t num_rows_; // total number of images in ImageFolder | |||
| int64_t row_cnt_; | |||
| int64_t buf_cnt_; | |||
| int64_t sampler_ind_; | |||
| @@ -29,7 +29,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_(0), builder_decode_(false) { | |||
| ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_decode_(false) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| @@ -39,16 +39,18 @@ ManifestOp::Builder::Builder() : builder_sampler_(nullptr), builder_num_samples_ | |||
| Status ManifestOp::Builder::Build(std::shared_ptr<ManifestOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| if (builder_sampler_ == nullptr) { | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK( | |||
| builder_schema_->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| RETURN_IF_NOT_OK( | |||
| builder_schema_->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | |||
| *ptr = std::make_shared<ManifestOp>( | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_file_, builder_op_connector_size_, builder_num_samples_, | |||
| builder_decode_, builder_labels_to_read_, std::move(builder_schema_), std::move(builder_sampler_), builder_usage_); | |||
| *ptr = std::make_shared<ManifestOp>(builder_num_workers_, builder_rows_per_buffer_, builder_file_, | |||
| builder_op_connector_size_, builder_decode_, builder_labels_to_read_, | |||
| std::move(builder_schema_), std::move(builder_sampler_), builder_usage_); | |||
| return Status::OK(); | |||
| } | |||
| @@ -59,9 +61,9 @@ Status ManifestOp::Builder::SanityCheck() { | |||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, | |||
| int64_t num_samples, bool decode, const std::map<std::string, int32_t> &class_index, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler, std::string usage) | |||
| ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | |||
| const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler, std::string usage) | |||
| : ParallelOp(num_works, queue_size), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| io_block_pushed_(0), | |||
| @@ -71,8 +73,6 @@ ManifestOp::ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string f | |||
| file_(file), | |||
| class_index_(class_index), | |||
| sampler_(std::move(sampler)), | |||
| num_samples_(num_samples), | |||
| num_rows_(0), | |||
| decode_(decode), | |||
| usage_(usage), | |||
| buf_cnt_(0) { | |||
| @@ -101,8 +101,7 @@ Status ManifestOp::AddIoBlock(std::unique_ptr<DataBuffer> *sampler_buffer) { | |||
| RETURN_IF_NOT_OK((*sampler_buffer)->PopRow(&sample_row)); | |||
| std::shared_ptr<Tensor> sample_ids = sample_row[0]; | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) { | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| if (row_cnt_ >= num_samples_) break; // enough row read, break for loop | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| keys.push_back(*itr); | |||
| row_cnt_++; | |||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||
| @@ -269,28 +268,6 @@ Status ManifestOp::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status ManifestOp::GetNumSamples(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status ManifestOp::GetNumRowsInDataset(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status ManifestOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const { | |||
| if (cls_ids == nullptr || !cls_ids->empty() || image_labelname_.empty()) { | |||
| @@ -408,7 +385,6 @@ Status ManifestOp::CountDatasetInfo() { | |||
| } | |||
| num_rows_ = static_cast<int64_t>(image_labelname_.size()); | |||
| num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; | |||
| if (num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API ManifestDataset.Please check file path or dataset API " | |||
| @@ -417,8 +393,8 @@ Status ManifestOp::CountDatasetInfo() { | |||
| return Status::OK(); | |||
| } | |||
| Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict, | |||
| const std::string &usage, int64_t *count, int64_t *numClasses) { | |||
| Status ManifestOp::CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, | |||
| int64_t *count, int64_t *numClasses) { | |||
| // the logic of counting the number of samples is copied from ParseManifestFile() | |||
| std::map<std::string, int32_t> map; | |||
| for (auto p : dict) { | |||
| @@ -428,17 +404,15 @@ Status ManifestOp::CountTotalRows(const std::string &file, int64_t numSamples, c | |||
| std::shared_ptr<ManifestOp> op; | |||
| *count = 0; | |||
| RETURN_IF_NOT_OK( | |||
| Builder().SetManifestFile(file).SetNumSamples(numSamples).SetClassIndex(map).SetUsage(usage).Build(&op)); | |||
| RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(map).SetUsage(usage).Build(&op)); | |||
| RETURN_IF_NOT_OK(op->ParseManifestFile()); | |||
| *numClasses = static_cast<int64_t>(op->label_index_.size()); | |||
| *count = static_cast<int64_t>(op->image_labelname_.size()); | |||
| *count = (*count < numSamples || numSamples == 0) ? *count : numSamples; | |||
| return Status::OK(); | |||
| } | |||
| Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict, | |||
| const std::string &usage, std::map<std::string, int32_t> *output_class_indexing) { | |||
| Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, | |||
| std::map<std::string, int32_t> *output_class_indexing) { | |||
| std::map<std::string, int32_t> input_class_indexing; | |||
| for (auto p : dict) { | |||
| (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first), | |||
| @@ -449,12 +423,7 @@ Status ManifestOp::GetClassIndexing(const std::string &file, int64_t numSamples, | |||
| *output_class_indexing = input_class_indexing; | |||
| } else { | |||
| std::shared_ptr<ManifestOp> op; | |||
| RETURN_IF_NOT_OK(Builder() | |||
| .SetManifestFile(file) | |||
| .SetNumSamples(numSamples) | |||
| .SetClassIndex(input_class_indexing) | |||
| .SetUsage(usage) | |||
| .Build(&op)); | |||
| RETURN_IF_NOT_OK(Builder().SetManifestFile(file).SetClassIndex(input_class_indexing).SetUsage(usage).Build(&op)); | |||
| RETURN_IF_NOT_OK(op->ParseManifestFile()); | |||
| RETURN_IF_NOT_OK(op->CountDatasetInfo()); | |||
| uint32_t count = 0; | |||
| @@ -86,14 +86,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param int64_t num_samples | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| @@ -129,7 +121,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| int64_t builder_num_samples_; | |||
| bool builder_decode_; | |||
| std::string builder_file_; | |||
| @@ -147,8 +138,8 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| // @param std::string - file list of Manifest | |||
| // @param int32_t queue_size - connector queue size | |||
| // @param td::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read | |||
| ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, int64_t num_samples, | |||
| bool decode, const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | |||
| ManifestOp(int32_t num_works, int32_t rows_per_buffer, std::string file, int32_t queue_size, bool decode, | |||
| const std::map<std::string, int32_t> &class_index, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler, std::string usage); | |||
| // Destructor. | |||
| ~ManifestOp() = default; | |||
| @@ -164,16 +155,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get numRows | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumSamples(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get total number of Rows in dataset | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumRowsInDataset(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | |||
| // @param (std::map<int64_t, std::vector<int64_t >> * map - key label, val all ids for this class | |||
| // @return Status - The error code return | |||
| @@ -184,12 +165,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| // @param show_all | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| static Status CountTotalRows(const std::string &file, int64_t numSamples, const py::dict &dict, | |||
| const std::string &usage, int64_t *count, int64_t *numClasses); | |||
| static Status CountTotalRows(const std::string &file, const py::dict &dict, const std::string &usage, int64_t *count, | |||
| int64_t *numClasses); | |||
| // Get str-to-int mapping from label name to index | |||
| static Status GetClassIndexing(const std::string &file, int64_t numSamples, const py::dict &dict, | |||
| const std::string &usage, std::map<std::string, int32_t> *output_class_indexing); | |||
| static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage, | |||
| std::map<std::string, int32_t> *output_class_indexing); | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| @@ -240,8 +221,6 @@ class ManifestOp : public ParallelOp, public RandomAccessOp { | |||
| std::string file_; // file that store the information of images | |||
| std::map<std::string, int32_t> class_index_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| int64_t num_samples_; | |||
| int64_t num_rows_; | |||
| bool decode_; | |||
| std::string usage_; | |||
| int64_t buf_cnt_; | |||
| @@ -91,7 +91,6 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, int32_t rows_per_buf | |||
| block_reader_(block_reader), | |||
| buffers_needed_(0), | |||
| buf_cnt_(0), | |||
| num_rows_(0), | |||
| ended_worker_(0), | |||
| buffer_water_mark_(0) { | |||
| io_blk_queues_.Init(num_workers_, op_connector_queue_size); | |||
| @@ -31,7 +31,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049; | |||
| const int32_t kMnistImageRows = 28; | |||
| const int32_t kMnistImageCols = 28; | |||
| MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) { | |||
| MnistOp::Builder::Builder() : builder_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| @@ -41,7 +41,9 @@ MnistOp::Builder::Builder() : builder_num_samples_(0), builder_sampler_(nullptr) | |||
| Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| if (builder_sampler_ == nullptr) { | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK( | |||
| @@ -49,9 +51,8 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) { | |||
| TensorShape scalar = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | |||
| ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | |||
| *ptr = | |||
| std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, builder_op_connector_size_, | |||
| builder_num_samples_, std::move(builder_schema_), std::move(builder_sampler_)); | |||
| *ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, | |||
| builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -60,17 +61,14 @@ Status MnistOp::Builder::SanityCheck() { | |||
| std::string err_msg; | |||
| err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; | |||
| err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; | |||
| err_msg += builder_num_samples_ < 0 ? "Number of samples is set to negative\n" : ""; | |||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, | |||
| int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| : ParallelOp(num_workers, queue_size), | |||
| buf_cnt_(0), | |||
| row_cnt_(0), | |||
| num_rows_(0), | |||
| num_samples_(num_samples), | |||
| folder_path_(folder_path), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| sampler_(std::move(sampler)), | |||
| @@ -84,8 +82,7 @@ MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folde | |||
| Status MnistOp::TraversalSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) { | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) { | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| if (row_cnt_ >= num_samples_) break; // enough row read, break for loop | |||
| if ((*itr) >= num_rows_) continue; // index out of bound, skipping | |||
| keys->push_back(*itr); | |||
| row_cnt_++; | |||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||
| @@ -219,17 +216,6 @@ Status MnistOp::InitSampler() { | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status MnistOp::GetNumSamples(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const { | |||
| if (cls_ids == nullptr || !cls_ids->empty() || image_label_pairs_.empty()) { | |||
| @@ -364,7 +350,6 @@ Status MnistOp::ParseMnistData() { | |||
| } | |||
| image_label_pairs_.shrink_to_fit(); | |||
| num_rows_ = image_label_pairs_.size(); | |||
| num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| @@ -414,11 +399,11 @@ Status MnistOp::LaunchThreadsAndInitOp() { | |||
| return Status::OK(); | |||
| } | |||
| Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count) { | |||
| Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) { | |||
| // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader() | |||
| std::shared_ptr<MnistOp> op; | |||
| *count = 0; | |||
| RETURN_IF_NOT_OK(Builder().SetDir(dir).SetNumSamples(numSamples).Build(&op)); | |||
| RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op)); | |||
| RETURN_IF_NOT_OK(op->WalkAllFiles()); | |||
| @@ -440,19 +425,6 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t numSamples, int64 | |||
| label_reader.close(); | |||
| } | |||
| *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples; | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status MnistOp::GetNumRowsInDataset(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API MnistDataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| @@ -78,14 +78,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param int64_t num_samples | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| // Setter method | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| @@ -114,7 +106,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| private: | |||
| std::string builder_dir_; | |||
| int32_t builder_num_workers_; | |||
| int64_t builder_num_samples_; | |||
| int32_t builder_rows_per_buffer_; | |||
| int32_t builder_op_connector_size_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| @@ -126,11 +117,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| // @param int32_t rows_per_buffer - number of images (rows) in each buffer | |||
| // @param std::string folder_path - dir directory of mnist | |||
| // @param int32_t queue_size - connector queue size | |||
| // @param int64_t num_samples - number of samples to read | |||
| // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset | |||
| // @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read | |||
| MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, | |||
| int64_t num_samples, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| // Destructor. | |||
| ~MnistOp() = default; | |||
| @@ -146,16 +136,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get numRows | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumSamples(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get total numRows in dataset | |||
| // @param int64_t num - to return numRows | |||
| // @return Status - The error code return | |||
| Status GetNumRowsInDataset(int64_t *num) const override; | |||
| // Method derived from RandomAccess Op, enable Sampler to get all ids for each class | |||
| // @param (std::map<uint64_t, std::vector<uint64_t >> * map - key label, val all ids for this class | |||
| // @return Status - The error code return | |||
| @@ -167,11 +147,10 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| // Function to count the number of samples in the MNIST dataset | |||
| // @param dir path to the MNSIT directory | |||
| // @param numSamples maximum number of samples requested | |||
| // @param dir path to the MNIST directory | |||
| // @param count output arg that will hold the minimum of the actual dataset size and numSamples | |||
| // @return | |||
| static Status CountTotalRows(const std::string &dir, int64_t numSamples, int64_t *count); | |||
| static Status CountTotalRows(const std::string &dir, int64_t *count); | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| @@ -244,9 +223,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||
| int64_t buf_cnt_; | |||
| int64_t row_cnt_; | |||
| int64_t num_rows_; // total number of images in Mnist | |||
| WaitPost wp_; | |||
| int64_t num_samples_; | |||
| std::string folder_path_; // directory of image folder | |||
| int32_t rows_per_buffer_; | |||
| std::shared_ptr<Sampler> sampler_; | |||
| @@ -8,6 +8,5 @@ add_library(engine-datasetops-source-sampler OBJECT | |||
| sampler.cc | |||
| sequential_sampler.cc | |||
| subset_random_sampler.cc | |||
| subset_sampler.cc | |||
| weighted_random_sampler.cc | |||
| ) | |||
| @@ -23,8 +23,9 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle, uint32_t seed) | |||
| : Sampler(), | |||
| DistributedSampler::DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed) | |||
| : Sampler(num_samples, std::numeric_limits<int64_t>::max()), | |||
| cnt_(0), | |||
| seed_(seed == std::numeric_limits<uint32_t>::max() ? GetSeed() : seed), | |||
| device_id_(dev_id), | |||
| @@ -32,6 +33,11 @@ DistributedSampler::DistributedSampler(int64_t num_dev, int64_t dev_id, bool shu | |||
| shuffle_(shuffle) {} | |||
| Status DistributedSampler::InitSampler() { | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| num_samples_ = num_rows_; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples <= 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(device_id_ < num_devices_ && device_id_ >= 0 && num_rows_ > 0 && num_samples_ > 0, | |||
| @@ -27,10 +27,11 @@ namespace mindspore { | |||
| namespace dataset { | |||
| class DistributedSampler : public Sampler { | |||
| public: | |||
| // @param int64_t numDev | |||
| // @param int64_t devId | |||
| // @param num_samples | |||
| // @param int64_t num_dev | |||
| // @param int64_t dev_id | |||
| // @param bool shuffle | |||
| DistributedSampler(int64_t num_dev, int64_t dev_id, bool shuffle = true, | |||
| DistributedSampler(int64_t num_samples, int64_t num_dev, int64_t dev_id, bool shuffle, | |||
| uint32_t seed = std::numeric_limits<uint32_t>::max()); | |||
| // default destructor | |||
| @@ -20,12 +20,11 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PKSampler::PKSampler(int64_t val, bool shuffle, int64_t samples_per_buffer) | |||
| : Sampler(samples_per_buffer), | |||
| PKSampler::PKSampler(int64_t num_samples, int64_t val, bool shuffle, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), | |||
| shuffle_(shuffle), | |||
| seed_(GetSeed()), | |||
| next_id_(0), | |||
| num_pk_samples_(0), | |||
| samples_per_class_(val) {} | |||
| Status PKSampler::InitSampler() { | |||
| @@ -36,22 +35,34 @@ Status PKSampler::InitSampler() { | |||
| } | |||
| } | |||
| rnd_.seed(seed_++); | |||
| num_pk_samples_ = samples_per_class_ * static_cast<int64_t>(labels_.size()); | |||
| samples_per_buffer_ = (samples_per_buffer_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_; | |||
| num_samples_ = num_pk_samples_; | |||
| // The special handshake gives the list of classes and id's, but it did not set the num_rows_ to | |||
| // capture the total number of possible sample ids. | |||
| // Compute that here for this case to find the total number of samples that are available to return. | |||
| // (in this case, samples per class * total classes). | |||
| num_rows_ = samples_per_class_ * static_cast<int64_t>(labels_.size()); | |||
| // The user may have chosen to sample less than the total amount. | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| num_samples_ = num_rows_; | |||
| } | |||
| samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; | |||
| if (shuffle_ == true) { | |||
| std::shuffle(labels_.begin(), labels_.end(), rnd_); | |||
| } else { | |||
| std::sort(labels_.begin(), labels_.end()); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_pk_samples_ > 0, "num_class or K (num samples per class) is not positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_class or K (num samples per class) is not positive"); | |||
| return Status::OK(); | |||
| } | |||
| Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (next_id_ > num_pk_samples_ || num_pk_samples_ == 0) { | |||
| if (next_id_ > num_samples_ || num_samples_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED("Index out of bound in PKSampler"); | |||
| } else if (next_id_ == num_pk_samples_) { | |||
| } else if (next_id_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| @@ -60,8 +71,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | |||
| std::shared_ptr<Tensor> sample_ids; | |||
| int64_t last_id = | |||
| (samples_per_buffer_ + next_id_ > num_pk_samples_) ? num_pk_samples_ : samples_per_buffer_ + next_id_; | |||
| int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; | |||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_)); | |||
| int64_t *id_ptr = reinterpret_cast<int64_t *>(sample_ids->GetMutableBuffer()); | |||
| while (next_id_ < last_id) { | |||
| @@ -85,7 +95,7 @@ Status PKSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| } | |||
| Status PKSampler::Reset() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_pk_samples_, "ERROR Reset() called early/late"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | |||
| next_id_ = 0; | |||
| rnd_.seed(seed_++); | |||
| @@ -28,10 +28,11 @@ namespace mindspore { | |||
| namespace dataset { | |||
| class PKSampler : public Sampler { // NOT YET FINISHED | |||
| public: | |||
| // @param int64_t kVal | |||
| // @param num_samples - the number of samples to draw. value of 0 means to take the full amount | |||
| // @param int64_t val | |||
| // @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2 | |||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit PKSampler(int64_t val, bool shuffle = false, | |||
| explicit PKSampler(int64_t num_samples, int64_t val, bool shuffle, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // default destructor | |||
| @@ -42,8 +43,9 @@ class PKSampler : public Sampler { // NOT YET FINISHED | |||
| // @return - The error code return | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| // first handshake between StorageOp and Sampler | |||
| // @param op - StorageOp pointer, pass in so Sampler can call GetNumSamples() and get ClassIds() | |||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | |||
| // in the dataset that we can sample from. | |||
| // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is | |||
| // @return | |||
| Status HandshakeRandomAccessOp(const RandomAccessOp *op) override; | |||
| @@ -58,7 +60,6 @@ class PKSampler : public Sampler { // NOT YET FINISHED | |||
| bool shuffle_; | |||
| uint32_t seed_; | |||
| int64_t next_id_; | |||
| int64_t num_pk_samples_; | |||
| int64_t samples_per_class_; | |||
| std::mt19937 rnd_; | |||
| std::vector<int64_t> labels_; | |||
| @@ -20,8 +20,8 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer) | |||
| : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||
| PythonSampler::PythonSampler(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} | |||
| Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (need_to_reset_) { | |||
| @@ -65,6 +65,11 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| Status PythonSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| num_samples_ = num_rows_; | |||
| } | |||
| { | |||
| py::gil_scoped_acquire gil_acquire; | |||
| if (Py_IsInitialized() == 0) { | |||
| @@ -26,8 +26,11 @@ namespace dataset { | |||
| class PythonSampler : public Sampler { | |||
| public: | |||
| // Constructor | |||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit PythonSampler(py::object py_sampler_instance, | |||
| // @param num_samples - the number of samples to draw. Value of 0 means to sample all of the | |||
| // data from the dataset. | |||
| // @param py_sampler_instance - the python instance of the sampler | |||
| // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit PythonSampler(int64_t num_samples, py::object py_sampler_instance, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| @@ -22,12 +22,11 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| RandomSampler::RandomSampler(bool replacement, bool reshuffle_each_epoch, int64_t num_samples, | |||
| RandomSampler::RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||
| int64_t samples_per_buffer) | |||
| : Sampler(samples_per_buffer), | |||
| : Sampler(num_samples, samples_per_buffer), | |||
| seed_(GetSeed()), | |||
| replacement_(replacement), | |||
| user_num_samples_(num_samples), | |||
| next_id_(0), | |||
| reshuffle_each_epoch_(reshuffle_each_epoch), | |||
| dist(nullptr) {} | |||
| @@ -70,27 +69,25 @@ Status RandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| } | |||
| Status RandomSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows needs to be positive."); | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| num_samples_ = num_rows_; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && num_rows_ > 0, "both num_samples & num_rows need to be positive"); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| rnd_.seed(seed_); | |||
| if (replacement_ == false) { | |||
| num_samples_ = std::min(num_samples_, num_rows_); | |||
| num_samples_ = std::min(num_samples_, user_num_samples_); | |||
| shuffled_ids_.reserve(num_rows_); | |||
| for (int64_t i = 0; i < num_rows_; i++) { | |||
| shuffled_ids_.push_back(i); | |||
| } | |||
| std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_); | |||
| } else { | |||
| num_samples_ = std::min(num_samples_, user_num_samples_); | |||
| dist = std::make_unique<std::uniform_int_distribution<int64_t>>(0, num_rows_ - 1); | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0, "num_samples needs to be positive."); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| return Status::OK(); | |||
| } | |||
| @@ -119,7 +116,6 @@ void RandomSampler::Print(std::ostream &out, bool show_all) const { | |||
| out << "(sampler): RandomSampler\n"; | |||
| if (show_all) { | |||
| out << "user_num_samples_: " << user_num_samples_ << '\n'; | |||
| out << "num_samples_: " << num_samples_ << '\n'; | |||
| out << "next_id_: " << next_id_ << '\n'; | |||
| } | |||
| @@ -27,11 +27,11 @@ namespace dataset { | |||
| class RandomSampler : public Sampler { | |||
| public: | |||
| // Constructor | |||
| // @param int64_t num_samples - number samples to draw | |||
| // @param bool replacement - put he id back / or not after a sample | |||
| // @param int64_t numSamples - number samples to draw | |||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit RandomSampler(bool replacement = false, bool reshuffle_each_epoch = true, | |||
| int64_t num_samples = std::numeric_limits<int64_t>::max(), | |||
| // @param reshuffle_each_epoch - T/F to reshuffle after epoch | |||
| // @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit RandomSampler(int64_t num_samples, bool replacement, bool reshuffle_each_epoch, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| @@ -55,7 +55,6 @@ class RandomSampler : public Sampler { | |||
| private: | |||
| uint32_t seed_; | |||
| bool replacement_; | |||
| int64_t user_num_samples_; | |||
| std::vector<int64_t> shuffled_ids_; // only used for NO REPLACEMENT | |||
| int64_t next_id_; | |||
| std::mt19937 rnd_; | |||
| @@ -19,8 +19,25 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Sampler::Sampler(int64_t samples_per_buffer) | |||
| : DatasetOp(0), num_rows_(0), num_samples_(0), samples_per_buffer_(samples_per_buffer), col_desc_(nullptr) {} | |||
| Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const { | |||
| // The sampler base class itself does not compute it's own num_rows_ value. | |||
| // Instead, this value is computed by the derived leaf op during it's own initialization | |||
| // after it has interacted with it's storage layers. | |||
| // Here, it is just a getter method to return the value. However, it is invalid if there is | |||
| // not a value set for this count, so generate a failure if that is the case. | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet."); | |||
| } | |||
| (*num) = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| Sampler::Sampler(int64_t num_samples, int64_t samples_per_buffer) | |||
| : DatasetOp(0), | |||
| num_rows_(0), | |||
| num_samples_(num_samples), | |||
| samples_per_buffer_(samples_per_buffer), | |||
| col_desc_(nullptr) {} | |||
| Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| std::shared_ptr<Sampler> child_sampler; | |||
| @@ -36,10 +53,10 @@ Status Sampler::HandshakeRandomAccessOp(const RandomAccessOp *op) { | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "RandomAccessOp is nullptr\n"); | |||
| RETURN_IF_NOT_OK(op->GetNumSamples(&num_samples_)); | |||
| // If there's a child sampler, set the row count to be it's sample count | |||
| if (HasChildSampler()) { | |||
| int64_t child_num_samples = child_sampler->num_samples(); | |||
| num_rows_ = child_num_samples; | |||
| num_rows_ = child_sampler->num_samples_; | |||
| } else { | |||
| RETURN_IF_NOT_OK(op->GetNumRowsInDataset(&num_rows_)); | |||
| } | |||
| @@ -105,7 +122,7 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { | |||
| } | |||
| Status Sampler::SetNumSamples(int64_t num_samples) { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples > 0, "num_samples is negative or 0"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples >= 0, "num_samples is negative"); | |||
| num_samples_ = num_samples; | |||
| return Status::OK(); | |||
| } | |||
| @@ -116,6 +133,16 @@ Status Sampler::SetNumRowsInDataset(int64_t num_rows) { | |||
| return Status::OK(); | |||
| } | |||
| // inline op doesn't have it's own consumer, it's assigned from parent | |||
| int32_t Sampler::num_consumers() const { | |||
| if (parent_.empty() || parent_[0] == nullptr) { | |||
| MS_LOG(WARNING) << "Sampler with no parent. num_consumers is 0."; | |||
| return 0; | |||
| } else { | |||
| return parent_[0]->num_consumers(); | |||
| } | |||
| } | |||
| Status Sampler::AddChild(std::shared_ptr<DatasetOp> child) { | |||
| if (child == nullptr) { | |||
| return Status::OK(); | |||
| @@ -155,5 +182,14 @@ Status Sampler::GetAssociatedChildId(int64_t *out_associated_id, int64_t id) { | |||
| return Status::OK(); | |||
| } | |||
| // inline op doesn't have it's own producers, it's assigned from child | |||
| int32_t Sampler::num_producers() const { | |||
| if (child_.empty() || child_[0] == nullptr) { | |||
| MS_LOG(WARNING) << "Sampler with no child, num_producers is 0."; | |||
| return 0; | |||
| } else { | |||
| return child_[0]->num_producers(); | |||
| } | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -33,23 +33,10 @@ namespace dataset { | |||
| // must inherit from if those leaf operator wish to support sampling. | |||
| class RandomAccessOp { | |||
| public: | |||
| // Sampler get numRows from StorageOp | |||
| // @param int64_t num - return number of rows, normally num of samples | |||
| // @return - The error code return | |||
| virtual Status GetNumSamples(int64_t *num_samples) const { | |||
| // CI complains num_samples not used if the following line is not added | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples != nullptr, "num_samples == nullptr"); | |||
| RETURN_STATUS_UNEXPECTED("function GetNumSamples needs to overridden to support this sampler"); | |||
| } | |||
| // Sampler get number of rows in the dataset! | |||
| // Sampler get number of rows in the dataset | |||
| // @param int64_t num - return number of rows for this dataset | |||
| // @return - The error code return | |||
| virtual Status GetNumRowsInDataset(int64_t *num_rows) const { | |||
| // CI complains num_rows not used if the following line is not added | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows != nullptr, "num_rows == nullptr"); | |||
| RETURN_STATUS_UNEXPECTED("function GetNumRowsInDataset needs to overridden to support this sampler"); | |||
| } | |||
| Status GetNumRowsInDataset(int64_t *num_rows) const; | |||
| // sampler gets label , imageIds from storageOp, this function is unique to PK | |||
| // @param std::map<int64_t, std::vector<int64_t>> * map | |||
| @@ -60,12 +47,20 @@ class RandomAccessOp { | |||
| // default destructor | |||
| virtual ~RandomAccessOp() = default; | |||
| protected: | |||
| // The amount of rows in the dataset itself. This is the before-sampling value, the | |||
| // total count of rows. A sampler may choose to sample less than this amount. | |||
| int64_t num_rows_; | |||
| }; | |||
| class Sampler : public DatasetOp { | |||
| public: | |||
| // Constructor | |||
| // @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0 | |||
| // indicates that the sampler should produce the complete set of ids. | |||
| // @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit Sampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| explicit Sampler(int64_t num_samples, int64_t samples_per_buffer); | |||
| // default destructor | |||
| ~Sampler() = default; | |||
| @@ -84,33 +79,36 @@ class Sampler : public DatasetOp { | |||
| // @return - The error code return | |||
| Status Reset() override = 0; | |||
| // setter function for num_rows_ | |||
| Status SetNumRowsInDataset(int64_t num_rows); | |||
| // setter function for num_samples_ | |||
| Status SetNumSamples(int64_t num_samples); | |||
| int64_t num_samples() { return num_samples_; } | |||
| // first handshake between StorageOp and Sampler. This func will call getNumRows and getNumSamples | |||
| // @param op - StorageOp pointer, pass in so Sampler can call getNumSamples() and get ClassIds() | |||
| // first handshake between leaf source op and Sampler. This func will determine the amount of data | |||
| // in the dataset that we can sample from. | |||
| // @param op - leaf op pointer, pass in so Sampler can ask it about how much data there is | |||
| // @return | |||
| virtual Status HandshakeRandomAccessOp(const RandomAccessOp *op); | |||
| // initialize sampler and perform checks on certain vars | |||
| virtual Status InitSampler() { return Status::OK(); } | |||
| // Not meant to be called | |||
| // setter for num samples | |||
| // @param num_samples - the number of samples to assign. | |||
| // @return status error code | |||
| Status SetNumSamples(int64_t num_samples); | |||
| // setter for num or records in the dataset | |||
| // @param num_rows - the number of records | |||
| // @return status error code | |||
| Status SetNumRowsInDataset(int64_t num_rows); | |||
| // Sampler is an inlined op and has no workers. Producers and consumers are computed. | |||
| // @return | |||
| int32_t num_workers() const final { return 0; } | |||
| // Not meant to be called | |||
| // Identify num consumers (inlined op) | |||
| // @return | |||
| int32_t num_consumers() const final { return 0; } | |||
| int32_t num_consumers() const final; | |||
| // Not meant to be called | |||
| // Identify num producers (inlined op) | |||
| // @return | |||
| int32_t num_producers() const final { return 0; } | |||
| int32_t num_producers() const final; | |||
| // Not meant to be called! | |||
| // @return - The error code return | |||
| @@ -151,10 +149,11 @@ class Sampler : public DatasetOp { | |||
| // output. Otherwise, num_rows_ is the number of rows in the dataset. | |||
| int64_t num_rows_; | |||
| // Number of ids this sampler will return. | |||
| // The user may want to sample less than the full amount of data. num_samples_ reduces the number | |||
| // of id's returned as request by the user. Derived classes will choose how to sample the smaller | |||
| // amount. | |||
| int64_t num_samples_; | |||
| // The max number of ids a DataBuffer returned by this sampler will contain. | |||
| int64_t samples_per_buffer_; | |||
| std::unique_ptr<ColDescriptor> col_desc_; | |||
| std::unique_ptr<DataBuffer> child_ids_; | |||
| @@ -20,34 +20,42 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| SequentialSampler::SequentialSampler(int64_t samples_per_buffer) : Sampler(samples_per_buffer), next_id_(0) {} | |||
| SequentialSampler::SequentialSampler(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {} | |||
| Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (next_id_ > num_samples_) { | |||
| RETURN_STATUS_UNEXPECTED("Sequential Sampler Internal Error"); | |||
| } else if (next_id_ == num_samples_) { | |||
| if (id_count_ > num_samples_) { | |||
| RETURN_STATUS_UNEXPECTED("SequentialSampler Internal Error"); | |||
| } else if (id_count_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(next_id_, DataBuffer::kDeBFlagNone); | |||
| (*out_buffer) = std::make_unique<DataBuffer>(current_id_, DataBuffer::kDeBFlagNone); | |||
| std::shared_ptr<Tensor> sampleIds; | |||
| int64_t lastId = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_; | |||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, lastId - next_id_)); | |||
| // Compute how many ids are left to pack, and pack this amount into a new buffer. Respect the setting for | |||
| // samples per buffer though. | |||
| int64_t remaining_ids = num_samples_ - id_count_; | |||
| int64_t num_elements = std::min(remaining_ids, samples_per_buffer_); | |||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sampleIds, num_elements)); | |||
| int64_t *idPtr = reinterpret_cast<int64_t *>(sampleIds->GetMutableBuffer()); | |||
| while (next_id_ < lastId) { | |||
| int64_t sampled_id = next_id_; | |||
| for (int64_t i = 0; i < num_elements; i++) { | |||
| int64_t sampled_id = current_id_; | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); | |||
| } | |||
| *idPtr = sampled_id; | |||
| next_id_++; | |||
| current_id_++; // Move the current id to the next one in the sequence | |||
| idPtr++; | |||
| } | |||
| id_count_ += num_elements; // Count the packed ids towards our overall sample count | |||
| TensorRow row(1, sampleIds); | |||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row)); | |||
| } | |||
| @@ -55,19 +63,24 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) | |||
| } | |||
| Status SequentialSampler::InitSampler() { | |||
| num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set | |||
| if (HasChildSampler()) { | |||
| num_samples_ = std::min(num_samples_, num_rows_); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ >= 0, "num_samples < 0\n"); | |||
| // Adjust the num_samples count based on the range of ids we are sequencing. If num_samples is 0, we sample | |||
| // the entire set. If it's non-zero, we will implicitly cap the amount sampled based on available data. | |||
| int64_t available_row_count = num_rows_ - start_index_; | |||
| if (num_samples_ == 0 || num_samples_ > available_row_count) { | |||
| num_samples_ = available_row_count; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); | |||
| samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; | |||
| return Status::OK(); | |||
| } | |||
| Status SequentialSampler::Reset() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(next_id_ == num_samples_, "ERROR Reset() called early/late"); | |||
| next_id_ = 0; | |||
| CHECK_FAIL_RETURN_UNEXPECTED(id_count_ == num_samples_, "ERROR Reset() called early/late"); | |||
| current_id_ = start_index_; | |||
| id_count_ = 0; | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||
| @@ -26,8 +26,12 @@ namespace dataset { | |||
| class SequentialSampler : public Sampler { | |||
| public: | |||
| // Constructor | |||
| // @param num_samples - The number of samples to draw. A value of 0 indicates the sampler should produce the | |||
| // full amount of ids from the dataset | |||
| // @param start_index - The starting index value | |||
| // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call | |||
| explicit SequentialSampler(int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| explicit SequentialSampler(int64_t num_samples, int64_t start_index, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| ~SequentialSampler() = default; | |||
| @@ -48,7 +52,9 @@ class SequentialSampler : public Sampler { | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| private: | |||
| int64_t next_id_; | |||
| int64_t current_id_; // The id sequencer. Each new id increments from this | |||
| int64_t start_index_; // The starting id. current_id_ begins from here. | |||
| int64_t id_count_; // An internal counter that tracks how many ids have been produced | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -27,22 +27,28 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor. | |||
| SubsetRandomSampler::SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t samples_per_buffer) | |||
| : Sampler(samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | |||
| SubsetRandomSampler::SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices, | |||
| int64_t samples_per_buffer) | |||
| : Sampler(num_samples, samples_per_buffer), indices_(indices), sample_id_(0), buffer_id_(0) {} | |||
| // Initialized this Sampler. | |||
| Status SubsetRandomSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "num_rows <= 0\n"); | |||
| num_samples_ = indices_.size(); | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // In this case, the id's are provided by the user. Cap the num_samples on the number of id's given. | |||
| if (num_samples_ == 0 || num_samples_ > static_cast<int64_t>(indices_.size())) { | |||
| num_samples_ = static_cast<int64_t>(indices_.size()); | |||
| } | |||
| // Initialize random generator with seed from config manager | |||
| rand_gen_.seed(GetSeed()); | |||
| if (static_cast<size_t>(samples_per_buffer_) > indices_.size()) { | |||
| samples_per_buffer_ = static_cast<int64_t>(indices_.size()); | |||
| if (samples_per_buffer_ > num_samples_) { | |||
| samples_per_buffer_ = num_samples_; | |||
| } | |||
| // num_samples_ could be smaller than the total number of input id's. | |||
| // We will shuffle the full set of id's, but only select the first num_samples_ of them later. | |||
| std::shuffle(indices_.begin(), indices_.end(), rand_gen_); | |||
| return Status::OK(); | |||
| @@ -68,7 +74,7 @@ Status SubsetRandomSampler::Reset() { | |||
| // Get the sample ids. | |||
| Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| // All samples have been drawn | |||
| if (sample_id_ == indices_.size()) { | |||
| if (sample_id_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| @@ -80,8 +86,8 @@ Status SubsetRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffe | |||
| int64_t last_id = sample_id_ + samples_per_buffer_; | |||
| // Handling the return all samples at once, and when last draw is not a full batch. | |||
| if (static_cast<size_t>(last_id) > indices_.size()) { | |||
| last_id = indices_.size(); | |||
| if (last_id > num_samples_) { | |||
| last_id = num_samples_; | |||
| } | |||
| // Allocate tensor | |||
| @@ -28,10 +28,11 @@ namespace dataset { | |||
| class SubsetRandomSampler : public Sampler { | |||
| public: | |||
| // Constructor. | |||
| // @param num_samples The number of samples to draw. 0 for the full amount. | |||
| // @param indices List of indices from where we will randomly draw samples. | |||
| // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | |||
| // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | |||
| explicit SubsetRandomSampler(const std::vector<int64_t> &indices, | |||
| explicit SubsetRandomSampler(int64_t num_samples, const std::vector<int64_t> &indices, | |||
| std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| @@ -1,85 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "dataset/engine/datasetops/source/sampler/subset_sampler.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include "dataset/core/config_manager.h" | |||
| #include "dataset/core/global_context.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor. | |||
| SubsetSampler::SubsetSampler(int64_t start_index, int64_t subset_size) | |||
| : Sampler(subset_size), start_index_(start_index), subset_size_(subset_size), current_id_(0) {} | |||
| Status SubsetSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(subset_size_ > 0, "subset_size <= 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ >= 0, "start_index < 0\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ < num_rows_, "start_index >= num_rows\n"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(start_index_ + subset_size_ - 1 < num_rows_, "Final index out of bounds.\n"); | |||
| num_samples_ = subset_size_; | |||
| return Status::OK(); | |||
| } | |||
| Status SubsetSampler::Reset() { | |||
| current_id_ = 0; | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->Reset()); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status SubsetSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) { | |||
| if (current_id_ > subset_size_) { | |||
| RETURN_STATUS_UNEXPECTED("SubsetSampler Internal Error"); | |||
| } else if (current_id_ == subset_size_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_ids_)); | |||
| } | |||
| (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone); | |||
| std::shared_ptr<Tensor> sampled_ids; | |||
| RETURN_IF_NOT_OK(CreateSamplerTensor(&sampled_ids, subset_size_)); | |||
| int64_t *sampled_ids_start_addr = reinterpret_cast<int64_t *>(sampled_ids->GetMutableBuffer()); | |||
| while (current_id_ < subset_size_) { | |||
| int64_t sampled_id = start_index_ + current_id_; | |||
| if (HasChildSampler()) { | |||
| RETURN_IF_NOT_OK(GetAssociatedChildId(&sampled_id, sampled_id)); | |||
| } | |||
| *(sampled_ids_start_addr + current_id_) = sampled_id; | |||
| current_id_++; | |||
| } | |||
| TensorRow sampled_ids_row(1, sampled_ids); | |||
| (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, sampled_ids_row)); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -1,58 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ | |||
| #define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "dataset/engine/datasetops/source/sampler/sampler.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class SubsetSampler : public Sampler { | |||
| public: | |||
| // Constructor. | |||
| // @param start_index The index we start sampling from. | |||
| explicit SubsetSampler(int64_t start_index, int64_t subset_size); | |||
| // Destructor. | |||
| ~SubsetSampler() = default; | |||
| // Initialize the sampler. | |||
| // @return Status | |||
| Status InitSampler() override; | |||
| // Reset the internal variable to the initial state and reshuffle the indices. | |||
| // @return Status | |||
| Status Reset() override; | |||
| // Get the sample ids. | |||
| // @param[out] out_buffer The address of a unique_ptr to DataBuffer where the sample ids will be placed. | |||
| // @note the sample ids (int64_t) will be placed in one Tensor. | |||
| Status GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer) override; | |||
| private: | |||
| int64_t start_index_; | |||
| int64_t subset_size_; | |||
| int64_t current_id_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SUBSET_SAMPLER_H_ | |||
| @@ -27,25 +27,28 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor. | |||
| WeightedRandomSampler::WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement, | |||
| WeightedRandomSampler::WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||
| int64_t samples_per_buffer) | |||
| : Sampler(samples_per_buffer), | |||
| : Sampler(num_samples, samples_per_buffer), | |||
| weights_(weights), | |||
| replacement_(replacement), | |||
| sample_id_(0), | |||
| buffer_id_(0), | |||
| user_num_samples_(num_samples) {} | |||
| buffer_id_(0) {} | |||
| // Initialized this Sampler. | |||
| Status WeightedRandomSampler::InitSampler() { | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && user_num_samples_, "num_samples & num_rows need to be positive"); | |||
| // Special value of 0 for num_samples means that the user wants to sample the entire set of data. | |||
| // If the user asked to sample more rows than exists in the dataset, adjust the num_samples accordingly. | |||
| if (num_samples_ == 0 || num_samples_ > num_rows_) { | |||
| num_samples_ = num_rows_; | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_, "num_samples & num_rows need to be positive"); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0, "samples_per_buffer<=0\n"); | |||
| num_samples_ = user_num_samples_; | |||
| // Initialize random generator with seed from config manager | |||
| rand_gen_.seed(GetSeed()); | |||
| samples_per_buffer_ = (samples_per_buffer_ > user_num_samples_) ? user_num_samples_ : samples_per_buffer_; | |||
| samples_per_buffer_ = (samples_per_buffer_ > num_samples_) ? num_samples_ : samples_per_buffer_; | |||
| if (!replacement_) { | |||
| exp_dist_ = std::make_unique<std::exponential_distribution<>>(1); | |||
| @@ -67,8 +70,8 @@ void WeightedRandomSampler::InitOnePassSampling() { | |||
| } | |||
| // Partial sort the first `numSamples` elements. | |||
| std::partial_sort(val_idx.begin(), val_idx.begin() + user_num_samples_, val_idx.end()); | |||
| for (int64_t i = 0; i < user_num_samples_; i++) { | |||
| std::partial_sort(val_idx.begin(), val_idx.begin() + num_samples_, val_idx.end()); | |||
| for (int64_t i = 0; i < num_samples_; i++) { | |||
| onepass_ids_.push_back(val_idx[i].second); | |||
| } | |||
| } | |||
| @@ -98,11 +101,11 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf | |||
| "number of samples weights is more than num of rows. Might generate id out of bound OR other errors"); | |||
| } | |||
| if (!replacement_ && (weights_.size() < static_cast<size_t>(user_num_samples_))) { | |||
| if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) { | |||
| RETURN_STATUS_UNEXPECTED("Without replacement, sample weights less than numSamples"); | |||
| } | |||
| if (sample_id_ == user_num_samples_) { | |||
| if (sample_id_ == num_samples_) { | |||
| (*out_buffer) = std::make_unique<DataBuffer>(buffer_id_++, DataBuffer::kDeBFlagEOE); | |||
| } else { | |||
| if (HasChildSampler()) { | |||
| @@ -114,8 +117,8 @@ Status WeightedRandomSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buf | |||
| int64_t last_id = sample_id_ + samples_per_buffer_; | |||
| // Handling the return all samples at once, and when last draw is not a full batch. | |||
| if (last_id > user_num_samples_) { | |||
| last_id = user_num_samples_; | |||
| if (last_id > num_samples_) { | |||
| last_id = num_samples_; | |||
| } | |||
| // Allocate tensor. | |||
| @@ -29,12 +29,12 @@ namespace dataset { | |||
| class WeightedRandomSampler : public Sampler { | |||
| public: | |||
| // Constructor. | |||
| // @param weights A lift of sample weights. | |||
| // @param num_samples Number of samples to be drawn. | |||
| // @param weights A lift of sample weights. | |||
| // @param replacement Determine if samples are drawn with/without replacement. | |||
| // @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer(). | |||
| // When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once. | |||
| WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples, bool replacement = true, | |||
| WeightedRandomSampler(int64_t num_samples, const std::vector<double> &weights, bool replacement, | |||
| int64_t samples_per_buffer = std::numeric_limits<int64_t>::max()); | |||
| // Destructor. | |||
| @@ -69,9 +69,6 @@ class WeightedRandomSampler : public Sampler { | |||
| // Random engine and device | |||
| std::mt19937 rand_gen_; | |||
| // num_samples from user | |||
| int64_t user_num_samples_; | |||
| // Discrete distribution for generating weighted random numbers with replacement. | |||
| std::unique_ptr<std::discrete_distribution<int64_t>> discrete_dist_; | |||
| @@ -33,7 +33,7 @@ | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| TextFileOp::Builder::Builder() | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { | |||
| : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) { | |||
| std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); | |||
| builder_num_workers_ = config_manager->num_parallel_workers(); | |||
| builder_op_connector_size_ = config_manager->op_connector_size(); | |||
| @@ -62,7 +62,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||
| builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>( | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, | |||
| std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, | |||
| builder_num_devices_, builder_device_id_); | |||
| RETURN_IF_NOT_OK(text_file_op->Init()); | |||
| @@ -71,14 +71,14 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) { | |||
| return Status::OK(); | |||
| } | |||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) | |||
| : ParallelOp(num_workers, op_connector_size), | |||
| device_id_(device_id), | |||
| num_devices_(num_device), | |||
| rows_per_buffer_(rows_per_buffer), | |||
| num_samples_(num_samples), | |||
| total_rows_(total_rows), | |||
| text_files_list_(std::move(text_files_list)), | |||
| shuffle_files_(shuffle_files), | |||
| data_schema_(std::move(schema)), | |||
| @@ -104,9 +104,9 @@ void TextFileOp::Print(std::ostream &out, bool show_all) const { | |||
| // Call the super class for displaying any common detailed info | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff | |||
| out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << num_samples_ | |||
| << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | |||
| << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nText files list:\n"; | |||
| out << "\nRows per buffer: " << rows_per_buffer_ << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ | |||
| << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") | |||
| << "\nText files list:\n"; | |||
| for (int i = 0; i < text_files_list_.size(); ++i) { | |||
| out << " " << text_files_list_[i]; | |||
| } | |||
| @@ -404,9 +404,9 @@ Status TextFileOp::operator()() { | |||
| RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); | |||
| if (buffer->eoe()) { | |||
| workers_done++; | |||
| } else if (num_samples_ == 0 || rows_read < num_samples_) { | |||
| if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { | |||
| int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); | |||
| } else if (total_rows_ == 0 || rows_read < total_rows_) { | |||
| if ((total_rows_ > 0) && (rows_read + buffer->NumRows() > total_rows_)) { | |||
| int64_t rowsToRemove = buffer->NumRows() - (total_rows_ - rows_read); | |||
| RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); | |||
| } | |||
| rows_read += buffer->NumRows(); | |||
| @@ -107,8 +107,8 @@ class TextFileOp : public ParallelOp { | |||
| // Setter method. | |||
| // @return Builder - setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| Builder &SetTotalRows(int64_t total_rows) { | |||
| builder_total_rows_ = total_rows; | |||
| return *this; | |||
| } | |||
| @@ -118,7 +118,7 @@ class TextFileOp : public ParallelOp { | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int64_t builder_rows_per_buffer_; | |||
| int64_t builder_num_samples_; | |||
| int64_t builder_total_rows_; | |||
| int32_t builder_worker_connector_size_; | |||
| std::vector<std::string> builder_text_files_list_; | |||
| bool builder_shuffle_files_; | |||
| @@ -136,7 +136,7 @@ class TextFileOp : public ParallelOp { | |||
| // @param columns_to_load - the names of the columns to load data from. | |||
| // @param shuffle_files - whether or not to shuffle the files before reading data. | |||
| // @param equal_rows_per_shard - whether or not to get equal rows for each process. | |||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, | |||
| TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size, | |||
| bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| @@ -246,7 +246,7 @@ class TextFileOp : public ParallelOp { | |||
| int32_t device_id_; | |||
| int32_t num_devices_; | |||
| int64_t rows_per_buffer_; | |||
| int64_t num_samples_; | |||
| int64_t total_rows_; | |||
| std::vector<std::string> text_files_list_; | |||
| bool shuffle_files_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| @@ -44,7 +44,7 @@ const char kSegmentationExtension[] = ".png"; | |||
| const char kAnnotationExtension[] = ".xml"; | |||
| const char kImageSetsExtension[] = ".txt"; | |||
| VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), builder_sampler_(nullptr) { | |||
| VOCOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) { | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| builder_num_workers_ = cfg->num_parallel_workers(); | |||
| builder_rows_per_buffer_ = cfg->rows_per_buffer(); | |||
| @@ -55,7 +55,9 @@ VOCOp::Builder::Builder() : builder_decode_(false), builder_num_samples_(0), bui | |||
| Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) { | |||
| RETURN_IF_NOT_OK(SanityCheck()); | |||
| if (builder_sampler_ == nullptr) { | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| builder_sampler_ = std::make_shared<SequentialSampler>(start_index, num_samples); | |||
| } | |||
| builder_schema_ = std::make_unique<DataSchema>(); | |||
| if (builder_task_type_ == TaskType::Segmentation) { | |||
| @@ -71,8 +73,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) { | |||
| } | |||
| *ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, | |||
| builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, | |||
| builder_num_samples_, builder_decode_, std::move(builder_schema_), | |||
| std::move(builder_sampler_)); | |||
| builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); | |||
| return Status::OK(); | |||
| } | |||
| @@ -81,20 +82,16 @@ Status VOCOp::Builder::SanityCheck() { | |||
| std::string err_msg; | |||
| err_msg += dir.IsDirectory() == false ? "VOC path is invalid or not set\n" : ""; | |||
| err_msg += builder_num_workers_ <= 0 ? "Num of parallel workers is set to 0 or negative\n" : ""; | |||
| err_msg += builder_num_samples_ < 0 ? "num_samples is negative\n" : ""; | |||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | |||
| } | |||
| VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | |||
| const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | |||
| int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler) | |||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||
| : ParallelOp(num_workers, queue_size), | |||
| decode_(decode), | |||
| row_cnt_(0), | |||
| buf_cnt_(0), | |||
| num_rows_(0), | |||
| num_samples_(num_samples), | |||
| task_type_(task_type), | |||
| task_mode_(task_mode), | |||
| folder_path_(folder_path), | |||
| @@ -112,7 +109,6 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std: | |||
| Status VOCOp::TraverseSampleIds(const std::shared_ptr<Tensor> &sample_ids, std::vector<int64_t> *keys) { | |||
| for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); ++itr) { | |||
| if ((*itr) > num_rows_) continue; | |||
| if (row_cnt_ == num_samples_) break; | |||
| keys->push_back(*itr); | |||
| row_cnt_++; | |||
| if (row_cnt_ % rows_per_buffer_ == 0) { | |||
| @@ -187,16 +183,6 @@ Status VOCOp::Reset() { | |||
| return Status::OK(); | |||
| } | |||
| Status VOCOp::GetNumSamples(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| Status VOCOp::LoadTensorRow(const std::string &image_id, TensorRow *trow) { | |||
| if (task_type_ == TaskType::Segmentation) { | |||
| std::shared_ptr<Tensor> image, target; | |||
| @@ -280,7 +266,6 @@ Status VOCOp::ParseImageIds() { | |||
| in_file.close(); | |||
| image_ids_.shrink_to_fit(); | |||
| num_rows_ = image_ids_.size(); | |||
| num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| @@ -305,7 +290,6 @@ Status VOCOp::ParseAnnotationIds() { | |||
| } | |||
| num_rows_ = image_ids_.size(); | |||
| num_samples_ = (num_samples_ == 0 || num_samples_ > num_rows_) ? num_rows_ : num_samples_; | |||
| return Status::OK(); | |||
| } | |||
| @@ -432,19 +416,8 @@ Status VOCOp::ReadAnnotationToTensor(const std::string &path, const ColDescripto | |||
| return Status::OK(); | |||
| } | |||
| // Derived from RandomAccessOp | |||
| Status VOCOp::GetNumRowsInDataset(int64_t *num) const { | |||
| if (num == nullptr || num_rows_ == 0) { | |||
| RETURN_STATUS_UNEXPECTED( | |||
| "There is no valid data matching the dataset API VOCDataset.Please check file path or dataset API " | |||
| "validation first."); | |||
| } | |||
| (*num) = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples, int64_t *count) { | |||
| const py::dict &dict, int64_t *count) { | |||
| if (task_type == "Detection") { | |||
| std::map<std::string, int32_t> input_class_indexing; | |||
| for (auto p : dict) { | |||
| @@ -464,14 +437,12 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ | |||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | |||
| *count = static_cast<int64_t>(op->image_ids_.size()); | |||
| } | |||
| *count = (numSamples == 0 || *count < numSamples) ? *count : numSamples; | |||
| return Status::OK(); | |||
| } | |||
| Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples, | |||
| std::map<std::string, int32_t> *output_class_indexing) { | |||
| const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing) { | |||
| std::map<std::string, int32_t> input_class_indexing; | |||
| for (auto p : dict) { | |||
| (void)input_class_indexing.insert(std::pair<std::string, int32_t>(py::reinterpret_borrow<py::str>(p.first), | |||
| @@ -116,14 +116,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @param int64_t num_samples | |||
| // @return Builder setter method returns reference to the builder. | |||
| Builder &SetNumSamples(int64_t num_samples) { | |||
| builder_num_samples_ = num_samples; | |||
| return *this; | |||
| } | |||
| // Setter method. | |||
| // @param std::shared_ptr<Sampler> sampler | |||
| // @return Builder setter method returns reference to the builder. | |||
| @@ -157,7 +149,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| int32_t builder_num_workers_; | |||
| int32_t builder_op_connector_size_; | |||
| int32_t builder_rows_per_buffer_; | |||
| int64_t builder_num_samples_; | |||
| std::shared_ptr<Sampler> builder_sampler_; | |||
| std::unique_ptr<DataSchema> builder_schema_; | |||
| std::map<std::string, int32_t> builder_labels_to_read_; | |||
| @@ -171,14 +162,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // @param int32_t num_workers - number of workers reading images in parallel | |||
| // @param int32_t rows_per_buffer - number of images (rows) in each buffer | |||
| // @param int32_t queue_size - connector queue size | |||
| // @param int64_t num_samples - number of samples to read | |||
| // @param bool decode - whether to decode images | |||
| // @param std::unique_ptr<DataSchema> data_schema - the schema of the VOC dataset | |||
| // @param std::shared_ptr<Sampler> sampler - sampler tells VOCOp what to read | |||
| VOCOp(const TaskType &task_type, const std::string &task_mode, const std::string &folder_path, | |||
| const std::map<std::string, int32_t> &class_index, int32_t num_workers, int32_t rows_per_buffer, | |||
| int32_t queue_size, int64_t num_samples, bool decode, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<Sampler> sampler); | |||
| int32_t queue_size, bool decode, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||
| // Destructor | |||
| ~VOCOp() = default; | |||
| @@ -194,15 +183,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // @return Status - The error code return | |||
| Status operator()() override; | |||
| // Method derived from RandomAccessOp, enable Sampler to get numRows | |||
| // @param uint64_t num - to return numRows | |||
| // return Status - The error code return | |||
| Status GetNumSamples(int64_t *num) const override; | |||
| // Method derived from RandomAccessOp, enable Sampler to get total number of rows in dataset | |||
| // @param uint64_t num - to return numRows | |||
| Status GetNumRowsInDataset(int64_t *num) const override; | |||
| // A print method typically used for debugging | |||
| // @param out | |||
| // @param show_all | |||
| @@ -212,10 +192,9 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // @param const std::string &task_type - task type of reading voc job | |||
| // @param const std::string &task_mode - task mode of reading voc job | |||
| // @param const py::dict &dict - input dict of class index | |||
| // @param int64_t numSamples - samples number of VOCDataset | |||
| // @param int64_t *count - output rows number of VOCDataset | |||
| static Status CountTotalRows(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples, int64_t *count); | |||
| const py::dict &dict, int64_t *count); | |||
| // @param const std::string &dir - VOC dir path | |||
| // @param const std::string &task_type - task type of reading voc job | |||
| @@ -224,8 +203,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| // @param int64_t numSamples - samples number of VOCDataset | |||
| // @param std::map<std::string, int32_t> *output_class_indexing - output class index of VOCDataset | |||
| static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode, | |||
| const py::dict &dict, int64_t numSamples, | |||
| std::map<std::string, int32_t> *output_class_indexing); | |||
| const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing); | |||
| private: | |||
| // Initialize Sampler, calls sampler->Init() within | |||
| @@ -283,8 +261,6 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||
| bool decode_; | |||
| int64_t row_cnt_; | |||
| int64_t buf_cnt_; | |||
| int64_t num_rows_; | |||
| int64_t num_samples_; | |||
| std::string folder_path_; | |||
| TaskType task_type_; | |||
| std::string task_mode_; | |||
| @@ -23,7 +23,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset | |||
| GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ | |||
| Schema, Shuffle, zip, RandomDataset | |||
| from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ | |||
| WeightedRandomSampler, SubsetSampler, Sampler | |||
| WeightedRandomSampler, Sampler | |||
| from .engine.serializer_deserializer import serialize, deserialize, show | |||
| from .engine.graphdata import GraphData | |||
| @@ -1261,8 +1261,8 @@ class MappableDataset(SourceDataset): | |||
| def _get_sampler_dataset_size(self): | |||
| if self.sampler is not None: | |||
| if hasattr(self.sampler, 'get_dataset_size'): | |||
| return self.sampler.get_dataset_size() | |||
| if hasattr(self.sampler, 'get_num_samples'): | |||
| return self.sampler.get_num_samples() | |||
| if hasattr(self.sampler, '__len__'): | |||
| return len(self.sampler) | |||
| @@ -1355,7 +1355,7 @@ class MappableDataset(SourceDataset): | |||
| random_sampler.reshuffle_each_epoch = False | |||
| ds.add_sampler(random_sampler) | |||
| subset_sampler = samplers.SubsetSampler(current_split_start_index, size) | |||
| subset_sampler = samplers.SequentialSampler(current_split_start_index, size) | |||
| ds.add_sampler(subset_sampler) | |||
| # add sequential sampler, so that if user calls use_sampler, we will | |||
| @@ -2226,31 +2226,45 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id): | |||
| num_shards (int): Number of shard for sharding. | |||
| shard_id (int): Shard ID. | |||
| """ | |||
| if input_sampler is not None: | |||
| # If the user provided a sampler, then it doesn't matter what the other args are because | |||
| # we are being asked specifically to use the given sampler. | |||
| # That means the following arguments: num_shards, shard_id, shuffle, num_samples should all | |||
| # be None. Consider this example: | |||
| # sampler = ds.DistributedSampler(num_shards=8, shard_id=3, shuffle=shuffle) | |||
| # data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_shards=4, shard_id=1) | |||
| # In this case, the user has given different sample-related arguments that contradict each other. | |||
| # To prevent this, only allow the user to manually specify the sampler if those arguments are all None | |||
| if (isinstance(input_sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||
| samplers.WeightedRandomSampler, samplers.Sampler)) and | |||
| (num_shards is not None or shard_id is not None or shuffle is not None or num_samples is not None)): | |||
| raise ValueError( | |||
| 'Conflicting arguments during sampler assignments. num_samples: {}, num_shards: {},' | |||
| ' shard_id: {}, shuffle: {})'.format(num_samples, num_shards, shard_id, shuffle)) | |||
| return input_sampler | |||
| if shuffle is None: | |||
| if input_sampler is not None: | |||
| # If shuffle is not specified, user provided sampler, use user's sampler | |||
| return input_sampler | |||
| if num_shards is not None: | |||
| # If shuffle is not specified, sharding enabled, use distributed random sampler | |||
| shuffle = True | |||
| return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle) | |||
| return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) | |||
| # If shuffle is not specified, sharding disabled, use random sampler | |||
| if num_samples is not None: | |||
| return samplers.RandomSampler(replacement=True, num_samples=num_samples) | |||
| return samplers.RandomSampler() | |||
| return samplers.RandomSampler(num_samples=num_samples) | |||
| if shuffle is True: | |||
| if num_shards is not None: | |||
| # If shuffle enabled, sharding enabled, use distributed random sampler | |||
| return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle) | |||
| return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) | |||
| # If shuffle enabled, sharding disabled, use random sampler | |||
| if num_samples is not None: | |||
| return samplers.RandomSampler(replacement=True, num_samples=num_samples) | |||
| return samplers.RandomSampler() | |||
| return samplers.RandomSampler(num_samples=num_samples) | |||
| if num_shards is not None: | |||
| # If shuffle disabled, sharding enabled, use distributed sequential sampler | |||
| return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle) | |||
| return samplers.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) | |||
| # If shuffle disabled, sharding disabled, use sequential sampler | |||
| return samplers.SequentialSampler() | |||
| return samplers.SequentialSampler(num_samples=num_samples) | |||
| class ImageFolderDatasetV2(MappableDataset): | |||
| @@ -2370,11 +2384,7 @@ class ImageFolderDatasetV2(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[0] | |||
| num_rows = ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[0] | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| @@ -2390,11 +2400,7 @@ class ImageFolderDatasetV2(MappableDataset): | |||
| Return: | |||
| Number, number of classes. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir, num_samples)[1] | |||
| return ImageFolderOp.get_num_rows_and_classes(self.dataset_dir)[1] | |||
| def is_shuffled(self): | |||
| if self.shuffle_level is None: | |||
| @@ -2503,12 +2509,7 @@ class MnistDataset(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| num_rows = MnistOp.get_num_rows(self.dataset_dir, num_samples) | |||
| num_rows = MnistOp.get_num_rows(self.dataset_dir) | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| @@ -2956,11 +2957,8 @@ class GeneratorDataset(MappableDataset): | |||
| if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, | |||
| samplers.RandomSampler, samplers.SubsetRandomSampler, | |||
| samplers.WeightedRandomSampler, samplers.Sampler)): | |||
| if num_samples is None: | |||
| num_samples = len(source) | |||
| sampler_instance = self.sampler.create() | |||
| sampler_instance.set_num_rows(len(source)) | |||
| sampler_instance.set_num_samples(num_samples) | |||
| sampler_instance.initialize() | |||
| if num_parallel_workers > 1: | |||
| self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) | |||
| @@ -3304,17 +3302,12 @@ class ManifestDataset(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| if self.class_indexing is None: | |||
| class_indexing = dict() | |||
| else: | |||
| class_indexing = self.class_indexing | |||
| num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[0] | |||
| num_rows = ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[0] | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| @@ -3330,17 +3323,12 @@ class ManifestDataset(MappableDataset): | |||
| Return: | |||
| Number, number of classes. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| if self.class_indexing is None: | |||
| class_indexing = dict() | |||
| else: | |||
| class_indexing = self.class_indexing | |||
| return ManifestOp.get_num_rows_and_classes(self.dataset_file, num_samples, class_indexing, self.usage)[1] | |||
| return ManifestOp.get_num_rows_and_classes(self.dataset_file, class_indexing, self.usage)[1] | |||
| def get_class_indexing(self): | |||
| """ | |||
| @@ -3349,17 +3337,12 @@ class ManifestDataset(MappableDataset): | |||
| Return: | |||
| Dict, A str-to-int mapping from label name to index. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| if self.class_indexing is None: | |||
| class_indexing = dict() | |||
| else: | |||
| class_indexing = self.class_indexing | |||
| return ManifestOp.get_class_indexing(self.dataset_file, num_samples, class_indexing, self.usage) | |||
| return ManifestOp.get_class_indexing(self.dataset_file, class_indexing, self.usage) | |||
| def is_shuffled(self): | |||
| if self.shuffle_level is None: | |||
| @@ -3473,12 +3456,8 @@ class Cifar10Dataset(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, True) | |||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, True) | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| @@ -3597,12 +3576,8 @@ class Cifar100Dataset(MappableDataset): | |||
| Return: | |||
| Number, number of batches. | |||
| """ | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, num_samples, False) | |||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, False) | |||
| rows_per_shard = get_num_rows(num_rows, self.num_shards) | |||
| rows_from_sampler = self._get_sampler_dataset_size() | |||
| @@ -3631,7 +3606,7 @@ class RandomDataset(SourceDataset): | |||
| Args: | |||
| num_samples (int): number of samples to generate. | |||
| schema (str or Schema, optional): Path to the json schema file or schema object (default=None). | |||
| If the schema is not provided, the meta data from the TFRecord file is considered the schema. | |||
| If the schema is not provided, the random dataset generates a random schema. | |||
| columns_list (list[str], optional): List of columns to be read (default=None, read all columns) | |||
| num_parallel_workers (int, optional): number of workers to read the data | |||
| (default=None, number set in the config). | |||
| @@ -3644,9 +3619,12 @@ class RandomDataset(SourceDataset): | |||
| schema_obj = Schema(schema) # read the schema file and convert to schema object to validate it | |||
| self.schema = schema | |||
| self.columns_list = columns_list | |||
| self.num_samples = num_samples | |||
| if schema_obj is not None and num_samples is None: | |||
| self.num_samples = schema_obj.num_rows | |||
| elif num_samples is None: | |||
| self.num_samples = 0 | |||
| else: | |||
| self.num_samples = num_samples | |||
| def get_args(self): | |||
| args = super().get_args() | |||
| @@ -4015,17 +3993,12 @@ class VOCDataset(MappableDataset): | |||
| if self.task != "Detection": | |||
| raise NotImplementedError() | |||
| if self.num_samples is None: | |||
| num_samples = 0 | |||
| else: | |||
| num_samples = self.num_samples | |||
| if self.class_indexing is None: | |||
| class_indexing = dict() | |||
| else: | |||
| class_indexing = self.class_indexing | |||
| return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) | |||
| return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing) | |||
| def is_shuffled(self): | |||
| if self.shuffle_level is None: | |||
| @@ -4205,9 +4178,11 @@ class TextFileDataset(SourceDataset): | |||
| if self._dataset_size is None: | |||
| num_rows = TextFileOp.get_num_rows(self.dataset_files) | |||
| num_rows = get_num_rows(num_rows, self.num_shards) | |||
| if self.num_samples is None: | |||
| return num_rows | |||
| return min(self.num_samples, num_rows) | |||
| # If the user gave a num samples in the dataset, then the sampler will limit the rows returned | |||
| # to that amount. Account for that here in the row count | |||
| if self.num_samples is not None and self.num_samples > 0 and num_rows > self.num_samples: | |||
| num_rows = self.num_samples | |||
| return num_rows | |||
| return self._dataset_size | |||
| def is_shuffled(self): | |||
| @@ -22,7 +22,6 @@ User can also define custom sampler by extending from Sampler class. | |||
| import numpy as np | |||
| import mindspore._c_dataengine as cde | |||
| class Sampler: | |||
| """ | |||
| Base class for user defined sampler. | |||
| @@ -44,10 +43,10 @@ class Sampler: | |||
| >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler()) | |||
| """ | |||
| def __init__(self): | |||
| def __init__(self, num_samples=None): | |||
| self.dataset_size = 0 | |||
| self.num_samples = 0 | |||
| self.child_sampler = None | |||
| self.num_samples = num_samples | |||
| def __iter__(self): | |||
| """ | |||
| @@ -84,7 +83,8 @@ class Sampler: | |||
| # Instance fetcher | |||
| # Do not override this method! | |||
| def create(self): | |||
| c_sampler = cde.PythonSampler(self) | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.PythonSampler(num_samples, self) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -114,7 +114,7 @@ class Sampler: | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| def get_num_samples(self): | |||
| return self._get_indices().size | |||
| @@ -124,8 +124,9 @@ class BuiltinSampler: | |||
| User should not extend this class. | |||
| """ | |||
| def __init__(self): | |||
| def __init__(self, num_samples=None): | |||
| self.child_sampler = None | |||
| self.num_samples = num_samples | |||
| def create(self): | |||
| pass | |||
| @@ -149,11 +150,37 @@ class BuiltinSampler: | |||
| def is_sharded(self): | |||
| raise NotImplementedError("Sampler must implement is_sharded.") | |||
| def get_dataset_size(self): | |||
| def get_num_samples(self): | |||
| """ | |||
| All samplers can contain a numeric num_samples value (or it could be set to None). | |||
| Child sampler can exist or be None. | |||
| if child sampler exists, then the child sampler count can be a numeric value or None. | |||
| Given these conditions, we need to output what the sampler count is for this sampler. | |||
| The following table shows the possible results from calling this function. | |||
| child sampler num_samples child_samples result | |||
| ------------- ----------- ------------- -------- | |||
| T x y min(x, y) | |||
| T x None x | |||
| T None y y | |||
| T None None None | |||
| None x n/a x | |||
| None None n/a None | |||
| Returns: | |||
| int, The number of samples, or None | |||
| """ | |||
| if self.child_sampler is not None: | |||
| return self.child_sampler.get_dataset_size() | |||
| child_samples = self.child_sampler.get_num_samples() | |||
| if self.num_samples is not None: | |||
| if child_samples is not None: | |||
| return min(self.num_samples, child_samples) | |||
| return self.num_samples | |||
| return None | |||
| return child_samples | |||
| return self.num_samples | |||
| class DistributedSampler(BuiltinSampler): | |||
| @@ -164,6 +191,7 @@ class DistributedSampler(BuiltinSampler): | |||
| num_shards (int): Number of shards to divide the dataset into. | |||
| shard_id (int): Shard ID of the current shard within num_shards. | |||
| shuffle (bool, optional): If true, the indices are shuffled (default=True). | |||
| num_samples (int, optional): The number of samples to draw (default=None, all elements). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -180,7 +208,7 @@ class DistributedSampler(BuiltinSampler): | |||
| ValueError: If shuffle is not a boolean value. | |||
| """ | |||
| def __init__(self, num_shards, shard_id, shuffle=True): | |||
| def __init__(self, num_shards, shard_id, shuffle=True, num_samples=None): | |||
| if num_shards <= 0: | |||
| raise ValueError("num_shards should be a positive integer value, but got num_shards={}".format(num_shards)) | |||
| @@ -194,12 +222,13 @@ class DistributedSampler(BuiltinSampler): | |||
| self.shard_id = shard_id | |||
| self.shuffle = shuffle | |||
| self.seed = 0 | |||
| super().__init__() | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| # each time user calls create_dict_iterator() (to do repeat) sampler would get a different seed to shuffle | |||
| self.seed += 1 | |||
| c_sampler = cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) | |||
| c_sampler = cde.DistributedSampler(num_samples, self.num_shards, self.shard_id, self.shuffle, self.seed) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -226,6 +255,7 @@ class PKSampler(BuiltinSampler): | |||
| num_class (int, optional): Number of classes to sample (default=None, all classes). | |||
| shuffle (bool, optional): If true, the class IDs are shuffled (default=False). | |||
| class_column (str, optional): Name of column to classify dataset(default='label'), for MindDataset. | |||
| num_samples (int, optional): The number of samples to draw (default=None, all elements). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -242,7 +272,7 @@ class PKSampler(BuiltinSampler): | |||
| ValueError: If shuffle is not boolean. | |||
| """ | |||
| def __init__(self, num_val, num_class=None, shuffle=False, class_column='label'): | |||
| def __init__(self, num_val, num_class=None, shuffle=False, class_column='label', num_samples=None): | |||
| if num_val <= 0: | |||
| raise ValueError("num_val should be a positive integer value, but got num_val={}".format(num_val)) | |||
| @@ -255,10 +285,11 @@ class PKSampler(BuiltinSampler): | |||
| self.num_val = num_val | |||
| self.shuffle = shuffle | |||
| self.class_column = class_column # work for minddataset | |||
| super().__init__() | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| c_sampler = cde.PKSampler(self.num_val, self.shuffle) | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.PKSampler(num_samples, self.num_val, self.shuffle) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -309,23 +340,18 @@ class RandomSampler(BuiltinSampler): | |||
| raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement)) | |||
| if num_samples is not None: | |||
| if num_samples <= 0: | |||
| if num_samples < 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples={}".format(num_samples)) | |||
| self.deterministic = False | |||
| self.replacement = replacement | |||
| self.num_samples = num_samples | |||
| self.reshuffle_each_epoch = True | |||
| super().__init__() | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| c_sampler = None | |||
| if self.num_samples is None: | |||
| c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch) | |||
| else: | |||
| c_sampler = cde.RandomSampler(self.replacement, self.reshuffle_each_epoch, self.num_samples) | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.RandomSampler(num_samples, self.replacement, self.reshuffle_each_epoch) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -339,84 +365,33 @@ class RandomSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| return self.num_samples | |||
| class SequentialSampler(BuiltinSampler): | |||
| """ | |||
| Samples the dataset elements sequentially, same as not having a sampler. | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> | |||
| >>> dataset_dir = "path/to/imagefolder_directory" | |||
| >>> | |||
| >>> # creates a SequentialSampler | |||
| >>> sampler = ds.SequentialSampler() | |||
| >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler) | |||
| """ | |||
| def create(self): | |||
| c_sampler = cde.SequentialSampler() | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| def is_shuffled(self): | |||
| if self.child_sampler is None: | |||
| return False | |||
| return self.child_sampler.is_shuffled() | |||
| def is_sharded(self): | |||
| if self.child_sampler is None: | |||
| return False | |||
| return self.child_sampler.is_sharded() | |||
| class SubsetSampler(BuiltinSampler): | |||
| """ | |||
| Samples a subset of elements consecutively from a given index. | |||
| Args: | |||
| start_index (int): Index to start sampling at. | |||
| subset_size (int): How many samples to include in this subset. | |||
| start_index (int, optional): Index to start sampling at. (dafault=None starts at first id) | |||
| num_samples (int, optional): Number of elements to sample (default=None, all elements). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| >>> | |||
| >>> dataset_dir = "path/to/imagefolder_directory" | |||
| >>> | |||
| >>> # creates a SubsetSampler, will sample the next 5 images from the 100th image. | |||
| >>> sampler = ds.SubsetSampler(100, 5) | |||
| >>> # creates a SequentialSampler | |||
| >>> sampler = ds.SequentialSampler() | |||
| >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler) | |||
| Raises: | |||
| ValueError: If start_index is not a positive int. | |||
| ValueError: If subset_size is not a positive int. | |||
| """ | |||
| def __init__(self, start_index, subset_size): | |||
| if not isinstance(start_index, int): | |||
| raise ValueError("start_index should be an int.") | |||
| if start_index < 0: | |||
| raise ValueError("start_index should not be negative.") | |||
| if not isinstance(subset_size, int): | |||
| raise ValueError("start_index should be an int") | |||
| if subset_size < 0: | |||
| raise ValueError("subset_size should not be negative.") | |||
| def __init__(self, start_index=None, num_samples=None): | |||
| self.start_index = start_index | |||
| self.subset_size = subset_size | |||
| super().__init__() | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| c_sampler = cde.SubsetSampler(self.start_index, self.subset_size) | |||
| start_index = self.start_index if self.start_index is not None else 0 | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.SequentialSampler(num_samples, start_index) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -433,9 +408,6 @@ class SubsetSampler(BuiltinSampler): | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| return self.subset_size | |||
| class SubsetRandomSampler(BuiltinSampler): | |||
| """ | |||
| @@ -443,6 +415,7 @@ class SubsetRandomSampler(BuiltinSampler): | |||
| Args: | |||
| indices (list[int]): A sequence of indices. | |||
| num_samples (int, optional): Number of elements to sample (default=None, all elements). | |||
| Examples: | |||
| >>> import mindspore.dataset as ds | |||
| @@ -456,15 +429,16 @@ class SubsetRandomSampler(BuiltinSampler): | |||
| >>> data = ds.ImageFolderDatasetV2(dataset_dir, num_parallel_workers=8, sampler=sampler) | |||
| """ | |||
| def __init__(self, indices): | |||
| def __init__(self, indices, num_samples=None): | |||
| if not isinstance(indices, list): | |||
| indices = [indices] | |||
| self.indices = indices | |||
| super().__init__() | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| c_sampler = cde.SubsetRandomSampler(self.indices) | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.SubsetRandomSampler(num_samples, self.indices) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -481,9 +455,9 @@ class SubsetRandomSampler(BuiltinSampler): | |||
| def _create_for_minddataset(self): | |||
| return cde.MindrecordSubsetRandomSampler(self.indices) | |||
| def get_dataset_size(self): | |||
| return len(self.indices) | |||
| def get_num_samples(self): | |||
| num_samples = super().get_num_samples() | |||
| return min(len(self.indices), num_samples) | |||
| class WeightedRandomSampler(BuiltinSampler): | |||
| @@ -492,7 +466,7 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| Args: | |||
| weights (list[float]): A sequence of weights, not necessarily summing up to 1. | |||
| num_samples (int): Number of elements to sample. | |||
| num_samples (int): Number of elements to sample (default=None, all elements). | |||
| replacement (bool, optional): If True, put the sample ID back for the next draw (default=True). | |||
| Examples: | |||
| @@ -511,24 +485,25 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| ValueError: If replacement is not boolean. | |||
| """ | |||
| def __init__(self, weights, num_samples, replacement=True): | |||
| def __init__(self, weights, num_samples=None, replacement=True): | |||
| if not isinstance(weights, list): | |||
| weights = [weights] | |||
| if num_samples <= 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples={}".format(num_samples)) | |||
| if num_samples is not None: | |||
| if num_samples < 0: | |||
| raise ValueError("num_samples should be a positive integer " | |||
| "value, but got num_samples={}".format(num_samples)) | |||
| if not isinstance(replacement, bool): | |||
| raise ValueError("replacement should be a boolean value, but got replacement={}".format(replacement)) | |||
| self.weights = weights | |||
| self.num_samples = num_samples | |||
| self.replacement = replacement | |||
| super().__init__() | |||
| super().__init__(num_samples) | |||
| def create(self): | |||
| c_sampler = cde.WeightedRandomSampler(self.weights, self.num_samples, self.replacement) | |||
| num_samples = self.num_samples if self.num_samples is not None else 0 | |||
| c_sampler = cde.WeightedRandomSampler(num_samples, self.weights, self.replacement) | |||
| c_child_sampler = self.create_child() | |||
| c_sampler.add_child(c_child_sampler) | |||
| return c_sampler | |||
| @@ -541,6 +516,3 @@ class WeightedRandomSampler(BuiltinSampler): | |||
| return False | |||
| return self.child_sampler.is_sharded() | |||
| def get_dataset_size(self): | |||
| return self.num_samples | |||
| @@ -161,6 +161,20 @@ def traverse(node): | |||
| else: | |||
| node_repr[k] = v | |||
| # If a sampler exists in this node, then the following 4 arguments must be set to None: | |||
| # num_samples, shard_id, num_shards, shuffle | |||
| # These arguments get moved into the sampler itself, so they are no longer needed to | |||
| # be set at the dataset level. | |||
| if 'sampler' in node_args.keys(): | |||
| if 'num_samples' in node_repr.keys(): | |||
| node_repr['num_samples'] = None | |||
| if 'shuffle' in node_repr.keys(): | |||
| node_repr['shuffle'] = None | |||
| if 'num_shards' in node_repr.keys(): | |||
| node_repr['num_shards'] = None | |||
| if 'shard_id' in node_repr.keys(): | |||
| node_repr['shard_id'] = None | |||
| # Leaf node doesn't have input attribute. | |||
| if not node.input: | |||
| return node_repr | |||
| @@ -283,8 +283,8 @@ def check_num_parallel_workers(value): | |||
| def check_num_samples(value): | |||
| check_type(value, 'num_samples', int) | |||
| if value <= 0: | |||
| raise ValueError("num_samples must be greater than 0!") | |||
| if value < 0: | |||
| raise ValueError("num_samples cannot be less than 0!") | |||
| def check_dataset_dir(dataset_dir): | |||
| @@ -39,14 +39,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, | |||
| const std::string &dir, int64_t num_samples = 0, | |||
| std::unique_ptr<Sampler> sampler = nullptr, bool decode = false, | |||
| const std::string &dataset_type="all") { | |||
| const std::string &dir, std::shared_ptr<Sampler> sampler = nullptr, | |||
| bool decode = false, const std::string &dataset_type="all") { | |||
| std::shared_ptr<CelebAOp> so; | |||
| CelebAOp::Builder builder; | |||
| Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer) | |||
| .SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode) | |||
| .SetNumSamples(num_samples).SetDatasetType(dataset_type).Build(&so); | |||
| .SetDatasetType(dataset_type).Build(&so); | |||
| return so; | |||
| } | |||
| @@ -116,11 +115,12 @@ TEST_F(MindDataTestCelebaDataset, TestCelebaRepeat) { | |||
| TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) { | |||
| std::vector<int64_t> indices({1}); | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices); | |||
| int64_t num_samples = 0; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices); | |||
| uint32_t expect_labels[1][40] = {{0,0,0,1,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1}}; | |||
| std::string dir = datasets_root_path_ + "/testCelebAData/"; | |||
| uint32_t count = 0; | |||
| auto tree = Build({Celeba(16, 2, 32, dir, 0, std::move(sampler))}); | |||
| auto tree = Build({Celeba(16, 2, 32, dir, std::move(sampler))}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -143,25 +143,3 @@ TEST_F(MindDataTestCelebaDataset, TestSubsetRandomSamplerCeleba) { | |||
| EXPECT_TRUE(count == 1); | |||
| } | |||
| } | |||
| TEST_F(MindDataTestCelebaDataset, TestCelebaNumSamples) { | |||
| std::string dir = datasets_root_path_ + "/testCelebAData/"; | |||
| uint32_t count = 0; | |||
| auto tree = Build({Celeba(16, 2, 32, dir, 1)}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "Return code error detected during tree launch: " << rc.ToString() << "."; | |||
| EXPECT_TRUE(false); | |||
| } else { | |||
| DatasetIterator di(tree); | |||
| TensorMap tersor_map; | |||
| di.GetNextAsMap(&tersor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| while (tersor_map.size() != 0) { | |||
| count++; | |||
| di.GetNextAsMap(&tersor_map); | |||
| } | |||
| EXPECT_TRUE(count == 1); | |||
| } | |||
| } | |||
| @@ -45,13 +45,12 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| std::shared_ptr<CifarOp> Cifarop(uint64_t num_works, uint64_t rows, uint64_t conns, std::string path, | |||
| std::unique_ptr<Sampler> sampler = nullptr, | |||
| uint64_t num_samples = 0, bool cifar10 = true) { | |||
| std::shared_ptr<Sampler> sampler = nullptr, bool cifar10 = true) { | |||
| std::shared_ptr<CifarOp> so; | |||
| CifarOp::Builder builder; | |||
| Status rc = builder.SetNumWorkers(num_works).SetCifarDir(path).SetRowsPerBuffer(rows) | |||
| .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetCifarType(cifar10) | |||
| .SetNumSamples(num_samples).Build(&so); | |||
| .Build(&so); | |||
| return so; | |||
| } | |||
| @@ -66,7 +65,7 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) { | |||
| //appear in this dataset | |||
| //Example: python tests/dataset/data/prep_data.py | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)}); | |||
| auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr)}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -79,7 +78,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) { | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| uint64_t i = 0; | |||
| uint32_t label = 0; | |||
| while (tensor_map.size() != 0) { | |||
| // Note: only iterating first 100 rows then break out. | |||
| while (tensor_map.size() != 0 && i < 100) { | |||
| tensor_map["label"]->GetItemAt<uint32_t>(&label, {}); | |||
| MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; | |||
| i++; | |||
| @@ -92,9 +92,9 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar10) { | |||
| TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) { | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| GlobalContext::config_manager()->set_seed(0); | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12); | |||
| std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(12, true, true); | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler), 100)}); | |||
| auto tree = Build({Cifarop(16, 2, 32, folder_path, std::move(sampler))}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -118,34 +118,9 @@ TEST_F(MindDataTestCifarOp, TestRandomSamplerCifar10) { | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| } | |||
| TEST_F(MindDataTestCifarOp, TestCifar10NumSample) { | |||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | |||
| auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100)}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << "."; | |||
| EXPECT_TRUE(false); | |||
| } else { | |||
| DatasetIterator di(tree); | |||
| TensorMap tensor_map; | |||
| di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| uint64_t i = 0; | |||
| uint32_t label = 0; | |||
| while (tensor_map.size() != 0) { | |||
| tensor_map["label"]->GetItemAt<uint32_t>(&label, {}); | |||
| MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; | |||
| i++; | |||
| di.GetNextAsMap(&tensor_map); | |||
| } | |||
| EXPECT_TRUE(i == 100); | |||
| } | |||
| } | |||
| TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) { | |||
| std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; | |||
| auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, 100, false)}); | |||
| auto tree = Build({Cifarop(16, 2, 32, folder_path, nullptr, false)}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -159,7 +134,8 @@ TEST_F(MindDataTestCifarOp, TestSequentialSamplerCifar100) { | |||
| uint64_t i = 0; | |||
| uint32_t coarse = 0; | |||
| uint32_t fine = 0; | |||
| while (tensor_map.size() != 0) { | |||
| // only iterate to 100 then break out of loop | |||
| while (tensor_map.size() != 0 && i < 100) { | |||
| tensor_map["coarse_label"]->GetItemAt<uint32_t>(&coarse, {}); | |||
| tensor_map["fine_label"]->GetItemAt<uint32_t>(&fine, {}); | |||
| MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << " coarse:" | |||
| @@ -50,9 +50,8 @@ std::shared_ptr<RepeatOp> Repeat(int repeat_cnt); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, | |||
| bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr, | |||
| std::map<std::string, int32_t> map = {}, int64_t num_samples = 0, | |||
| bool decode = false) { | |||
| bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr, | |||
| std::map<std::string, int32_t> map = {}, bool decode = false) { | |||
| std::shared_ptr<ImageFolderOp> so; | |||
| ImageFolderOp::Builder builder; | |||
| Status rc = builder.SetNumWorkers(num_works) | |||
| @@ -63,7 +62,6 @@ std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int6 | |||
| .SetSampler(std::move(sampler)) | |||
| .SetClassIndex(map) | |||
| .SetDecode(decode) | |||
| .SetNumSamples(num_samples) | |||
| .Build(&so); | |||
| return so; | |||
| } | |||
| @@ -138,7 +136,8 @@ TEST_F(MindDataTestImageFolderSampler, TestRandomImageFolder) { | |||
| TEST_F(MindDataTestImageFolderSampler, TestRandomSamplerImageFolder) { | |||
| int32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| GlobalContext::config_manager()->set_seed(0); | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<RandomSampler>(true, true, 12); | |||
| int64_t num_samples = 12; | |||
| std::shared_ptr<Sampler> sampler = std::make_unique<RandomSampler>(num_samples, true, true); | |||
| int32_t res[] = {2, 2, 2, 3, 2, 3, 2, 3, 1, 2, 2, 1}; // ground truth label | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); | |||
| @@ -200,7 +199,8 @@ TEST_F(MindDataTestImageFolderSampler, TestSequentialImageFolderWithRepeatBatch) | |||
| TEST_F(MindDataTestImageFolderSampler, TestSubsetRandomSamplerImageFolder) { | |||
| // id range 0 - 10 is label 0, and id range 11 - 21 is label 1 | |||
| std::vector<int64_t> indices({0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 11}); | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices); | |||
| int64_t num_samples = 0; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices); | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| // Expect 6 samples for label 0 and 1 | |||
| int res[2] = {6, 6}; | |||
| @@ -237,8 +237,8 @@ TEST_F(MindDataTestImageFolderSampler, TestWeightedRandomSamplerImageFolder) { | |||
| std::vector<double> weights(total_samples, std::rand() % 100); | |||
| // create sampler with replacement = replacement | |||
| std::unique_ptr<Sampler> sampler = | |||
| std::make_unique<WeightedRandomSampler>(weights, num_samples, true, samples_per_buffer); | |||
| std::shared_ptr<Sampler> sampler = | |||
| std::make_shared<WeightedRandomSampler>(num_samples, weights, true, samples_per_buffer); | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); | |||
| @@ -295,7 +295,8 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderClassIndex) { | |||
| } | |||
| TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(11, 10, false); | |||
| int64_t num_samples = 0; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 11, 10, false); | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler)), Repeat(4)}); | |||
| tree->Prepare(); | |||
| @@ -322,7 +323,8 @@ TEST_F(MindDataTestImageFolderSampler, TestDistributedSampler) { | |||
| } | |||
| TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) { | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<PKSampler>(3, false, 4); | |||
| int64_t num_samples = 0; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<PKSampler>(num_samples, 3, false, 4); | |||
| int32_t res[] = {0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}; // ground truth label | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler))}); | |||
| @@ -349,39 +351,16 @@ TEST_F(MindDataTestImageFolderSampler, TestPKSamplerImageFolder) { | |||
| } | |||
| } | |||
| TEST_F(MindDataTestImageFolderSampler, TestImageFolderNumSamples) { | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, {}, 11), Repeat(2)}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| MS_LOG(ERROR) << "Return code error detected during tree launch: " << common::SafeCStr(rc.ToString()) << "."; | |||
| EXPECT_TRUE(false); | |||
| } else { | |||
| DatasetIterator di(tree); | |||
| TensorMap tensor_map; | |||
| di.GetNextAsMap(&tensor_map); | |||
| EXPECT_TRUE(rc.IsOk()); | |||
| uint64_t i = 0; | |||
| int32_t label = 0; | |||
| while (tensor_map.size() != 0) { | |||
| tensor_map["label"]->GetItemAt<int32_t>(&label, {}); | |||
| EXPECT_TRUE(0 == label); | |||
| MS_LOG(DEBUG) << "row: " << i << "\t" << tensor_map["image"]->shape() << "label:" << label << "\n"; | |||
| i++; | |||
| di.GetNextAsMap(&tensor_map); | |||
| } | |||
| EXPECT_TRUE(i == 22); | |||
| } | |||
| } | |||
| TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) { | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| std::map<std::string, int32_t> map; | |||
| map["class3"] = 333; | |||
| map["class1"] = 111; | |||
| map["wrong folder name"] = 1234; // this is skipped | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, nullptr, map, 20, true)}); | |||
| int64_t num_samples = 20; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(seq_sampler), map, true)}); | |||
| int64_t res[2] = {111, 333}; | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| @@ -408,33 +387,12 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderDecode) { | |||
| } | |||
| } | |||
| TEST_F(MindDataTestImageFolderSampler, TestImageFolderDatasetSize) { | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| int64_t num_rows = 0; | |||
| int64_t num_classes = 0; | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 15, {}, &num_rows, &num_classes); | |||
| EXPECT_TRUE(num_rows == 15 && num_classes == 4); | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes); | |||
| EXPECT_TRUE(num_rows == 44 && num_classes == 4); | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 0, {}, &num_rows, &num_classes); | |||
| EXPECT_TRUE(num_rows == 44 && num_classes == 4); | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 55, {}, &num_rows, &num_classes); | |||
| EXPECT_TRUE(num_rows == 44 && num_classes == 4); | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 44, {}, &num_rows, &num_classes, 2, 3); | |||
| EXPECT_TRUE(num_rows == 15 && num_classes == 4); | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 33, {}, &num_rows, &num_classes, 0, 3); | |||
| EXPECT_TRUE(num_rows == 15 && num_classes == 4); | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 13, {}, &num_rows, &num_classes, 0, 11); | |||
| EXPECT_TRUE(num_rows == 4 && num_classes == 4); | |||
| ImageFolderOp::CountRowsAndClasses(folder_path, 3, {}, &num_rows, &num_classes, 0, 11); | |||
| EXPECT_TRUE(num_rows == 3 && num_classes == 4); | |||
| } | |||
| TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) { | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 0, false); | |||
| int64_t num_samples = 5; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 0, false); | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {}, 5)}); | |||
| auto tree = Build({ImageFolder(16, 2, 32, folder_path, false, std::move(sampler), {})}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| int32_t labels[5] = {0, 0, 0, 1, 1}; | |||
| @@ -460,10 +418,11 @@ TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding1) { | |||
| } | |||
| TEST_F(MindDataTestImageFolderSampler, TestImageFolderSharding2) { | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(4, 3, false); | |||
| int64_t num_samples = 12; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 4, 3, false); | |||
| std::string folder_path = datasets_root_path_ + "/testPK/data"; | |||
| // numWrks, rows, conns, path, shuffle, sampler, map, numSamples, decode | |||
| auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {}, 12)}); | |||
| auto tree = Build({ImageFolder(16, 16, 32, folder_path, false, std::move(sampler), {})}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| uint32_t labels[11] = {0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3}; | |||
| @@ -23,6 +23,7 @@ | |||
| #include "dataset/core/client.h" | |||
| #include "dataset/core/global_context.h" | |||
| #include "dataset/engine/datasetops/source/manifest_op.h" | |||
| #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" | |||
| #include "dataset/util/de_error.h" | |||
| #include "dataset/util/status.h" | |||
| @@ -42,14 +43,13 @@ std::shared_ptr<RepeatOp> Repeat(int repeatCnt); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| std::shared_ptr<ManifestOp> Manifest(int32_t num_works, int32_t rows, int32_t conns, const std::string &file, | |||
| std::string usage = "train", std::unique_ptr<Sampler> sampler = nullptr, | |||
| std::map<std::string, int32_t> map = {}, uint64_t num_samples = 0, | |||
| bool decode = false) { | |||
| std::string usage = "train", std::shared_ptr<Sampler> sampler = nullptr, | |||
| std::map<std::string, int32_t> map = {}, bool decode = false) { | |||
| std::shared_ptr<ManifestOp> so; | |||
| ManifestOp::Builder builder; | |||
| Status rc = builder.SetNumWorkers(num_works).SetManifestFile(file).SetRowsPerBuffer( | |||
| rows).SetOpConnectorSize(conns).SetSampler(std::move(sampler)).SetClassIndex(map).SetDecode(decode) | |||
| .SetNumSamples(num_samples).SetUsage(usage).Build(&so); | |||
| .SetUsage(usage).Build(&so); | |||
| return so; | |||
| } | |||
| @@ -86,7 +86,8 @@ TEST_F(MindDataTestManifest, TestSequentialManifestWithRepeat) { | |||
| TEST_F(MindDataTestManifest, TestSubsetRandomSamplerManifest) { | |||
| std::vector<int64_t> indices({1}); | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<SubsetRandomSampler>(indices); | |||
| int64_t num_samples = 0; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<SubsetRandomSampler>(num_samples, indices); | |||
| std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; | |||
| // Expect 6 samples for label 0 and 1 | |||
| auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(sampler))}); | |||
| @@ -145,7 +146,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestClassIndex) { | |||
| TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { | |||
| std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; | |||
| auto tree = Build({Manifest(16, 2, 32, file, "train", nullptr, {}, 1), Repeat(4)}); | |||
| int64_t num_samples = 1; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| auto tree = Build({Manifest(16, 2, 32, file, "train", std::move(seq_sampler), {}), Repeat(4)}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -171,7 +175,10 @@ TEST_F(MindDataTestManifest, MindDataTestManifestNumSamples) { | |||
| TEST_F(MindDataTestManifest, MindDataTestManifestEval) { | |||
| std::string file = datasets_root_path_ + "/testManifestData/cpp.json"; | |||
| auto tree = Build({Manifest(16, 2, 32, file, "eval", nullptr, {}, 1)}); | |||
| int64_t num_samples = 1; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| auto tree = Build({Manifest(16, 2, 32, file, "eval", std::move(seq_sampler), {})}); | |||
| tree->Prepare(); | |||
| Status rc = tree->Launch(); | |||
| if (rc.IsError()) { | |||
| @@ -120,9 +120,8 @@ class MindDataTestMapOp : public UT::DatasetOpTesting { | |||
| }; | |||
| std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path, | |||
| bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr, | |||
| std::map<std::string, int32_t> map = {}, int64_t num_samples = 0, | |||
| bool decode = false); | |||
| bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr, | |||
| std::map<std::string, int32_t> map = {}, bool decode = false); | |||
| std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops); | |||
| @@ -53,13 +53,11 @@ Status Create1DTensor(std::shared_ptr<Tensor> *sample_ids, int64_t num_elements, | |||
| DataType::Type data_type = DataType::DE_UINT32); | |||
| std::shared_ptr<MnistOp> CreateMnist(int64_t num_wrks, int64_t rows, int64_t conns, std::string path, | |||
| bool shuf = false, std::unique_ptr<Sampler> sampler = nullptr, | |||
| int64_t num_samples = 0) { | |||
| bool shuf = false, std::shared_ptr<Sampler> sampler = nullptr) { | |||
| std::shared_ptr<MnistOp> so; | |||
| MnistOp::Builder builder; | |||
| Status rc = builder.SetNumWorkers(num_wrks).SetDir(path).SetRowsPerBuffer(rows) | |||
| .SetOpConnectorSize(conns).SetSampler(std::move(sampler)) | |||
| .SetNumSamples(num_samples).Build(&so); | |||
| .SetOpConnectorSize(conns).SetSampler(std::move(sampler)).Build(&so); | |||
| return so; | |||
| } | |||
| @@ -74,7 +72,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) { | |||
| // appear in this dataset | |||
| // Example: python tests/dataset/data/prep_data.py | |||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | |||
| auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2)}); | |||
| int64_t num_samples = 10; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2)}); | |||
| tree->Prepare(); | |||
| uint32_t res[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; | |||
| Status rc = tree->Launch(); | |||
| @@ -101,7 +102,10 @@ TEST_F(MindDataTestMnistSampler, TestSequentialMnistWithRepeat) { | |||
| TEST_F(MindDataTestMnistSampler, TestSequentialImageFolderWithRepeatBatch) { | |||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | |||
| auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, nullptr, 10), Repeat(2), Batch(5)}); | |||
| int64_t num_samples = 10; | |||
| int64_t start_index = 0; | |||
| auto seq_sampler = std::make_shared<SequentialSampler>(num_samples, start_index); | |||
| auto tree = Build({CreateMnist(16, 2, 32, folder_path, false, std::move(seq_sampler)), Repeat(2), Batch(5)}); | |||
| tree->Prepare(); | |||
| uint32_t res[4][5] = { {0, 0, 0, 0, 0 }, | |||
| {0, 0, 0, 0, 0 }, | |||
| @@ -43,20 +43,11 @@ class MindDataTestStandAloneSampler : public UT::DatasetOpTesting { | |||
| protected: | |||
| class MockStorageOp : public RandomAccessOp { | |||
| public: | |||
| MockStorageOp(int64_t val) : m_val_(val) {} | |||
| Status GetNumSamples(int64_t *ptr) const override { | |||
| (*ptr) = m_val_; | |||
| return Status::OK(); | |||
| } | |||
| Status GetNumRowsInDataset(int64_t *ptr) const override { | |||
| (*ptr) = m_val_; | |||
| return Status::OK(); | |||
| MockStorageOp(int64_t val){ | |||
| // row count is in base class as protected member | |||
| // GetNumRowsInDataset does not need an override, the default from base class is fine. | |||
| num_rows_ = val; | |||
| } | |||
| private: | |||
| int64_t m_val_; | |||
| }; | |||
| }; | |||
| @@ -73,8 +64,9 @@ TEST_F(MindDataTestStandAloneSampler, TestDistributedSampler) { | |||
| MockStorageOp mock(20); | |||
| std::unique_ptr<DataBuffer> db; | |||
| std::shared_ptr<Tensor> tensor; | |||
| int64_t num_samples = 0; | |||
| for (int i = 0; i < 6; i++) { | |||
| std::unique_ptr<Sampler> sampler = std::make_unique<DistributedSampler>(3, i % 3, (i < 3 ? false : true)); | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<DistributedSampler>(num_samples, 3, i % 3, (i < 3 ? false : true)); | |||
| sampler->HandshakeRandomAccessOp(&mock); | |||
| sampler->GetNextBuffer(&db); | |||
| db->GetTensor(&tensor, 0, 0); | |||
| @@ -92,7 +84,9 @@ TEST_F(MindDataTestStandAloneSampler, TestStandAoneSequentialSampler) { | |||
| std::shared_ptr<Tensor> label1, label2; | |||
| CreateINT64Tensor(&label1, 3, reinterpret_cast<unsigned char *>(res)); | |||
| CreateINT64Tensor(&label2, 2, reinterpret_cast<unsigned char *>(res + 3)); | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(3); | |||
| int64_t num_samples = 0; | |||
| int64_t start_index = 0; | |||
| std::shared_ptr<Sampler> sampler = std::make_shared<SequentialSampler>(num_samples, start_index, 3); | |||
| std::unique_ptr<DataBuffer> db; | |||
| std::shared_ptr<Tensor> tensor; | |||
| sampler->HandshakeRandomAccessOp(&mock); | |||
| @@ -31,26 +31,17 @@ class MindDataTestSubsetRandomSampler : public UT::Common { | |||
| public: | |||
| class DummyRandomAccessOp : public RandomAccessOp { | |||
| public: | |||
| DummyRandomAccessOp(int64_t num_rows) : num_rows_(num_rows) {}; | |||
| Status GetNumSamples(int64_t *num) const { | |||
| *num = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| Status GetNumRowsInDataset(int64_t *num) const { | |||
| *num = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| int64_t num_rows_; | |||
| DummyRandomAccessOp(int64_t num_rows) { | |||
| num_rows_ = num_rows; // base class | |||
| }; | |||
| }; | |||
| }; | |||
| TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||
| std::vector<int64_t> in({0, 1, 2, 3, 4}); | |||
| std::unordered_set<int64_t> in_set(in.begin(), in.end()); | |||
| SubsetRandomSampler sampler(in); | |||
| int64_t num_samples = 0; | |||
| SubsetRandomSampler sampler(num_samples, in); | |||
| DummyRandomAccessOp dummyRandomAccessOp(5); | |||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -77,8 +68,9 @@ TEST_F(MindDataTestSubsetRandomSampler, TestAllAtOnce) { | |||
| TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||
| int64_t total_samples = 100000 - 5; | |||
| int64_t samples_per_buffer = 10; | |||
| int64_t num_samples = 0; | |||
| std::vector<int64_t> input(total_samples, 1); | |||
| SubsetRandomSampler sampler(input, samples_per_buffer); | |||
| SubsetRandomSampler sampler(num_samples, input, samples_per_buffer); | |||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -109,7 +101,8 @@ TEST_F(MindDataTestSubsetRandomSampler, TestGetNextBuffer) { | |||
| TEST_F(MindDataTestSubsetRandomSampler, TestReset) { | |||
| std::vector<int64_t> in({0, 1, 2, 3, 4}); | |||
| std::unordered_set<int64_t> in_set(in.begin(), in.end()); | |||
| SubsetRandomSampler sampler(in); | |||
| int64_t num_samples = 0; | |||
| SubsetRandomSampler sampler(num_samples, in); | |||
| DummyRandomAccessOp dummyRandomAccessOp(5); | |||
| sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -35,19 +35,11 @@ class MindDataTestWeightedRandomSampler : public UT::Common { | |||
| public: | |||
| class DummyRandomAccessOp : public RandomAccessOp { | |||
| public: | |||
| DummyRandomAccessOp(uint64_t num_rows) : num_rows_(num_rows) {}; | |||
| Status GetNumSamples(int64_t *num) const { | |||
| *num = num_rows_; | |||
| return Status::OK(); | |||
| DummyRandomAccessOp(uint64_t num_rows) { | |||
| // row count is in base class as protected member | |||
| // GetNumRowsInDataset does not need an override, the default from base class is fine. | |||
| num_rows_ = num_rows; | |||
| } | |||
| Status GetNumRowsInDataset(int64_t *num) const { | |||
| *num = num_rows_; | |||
| return Status::OK(); | |||
| } | |||
| private: | |||
| uint64_t num_rows_; | |||
| }; | |||
| }; | |||
| @@ -59,7 +51,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotReplacement) { | |||
| std::vector<uint64_t> freq(total_samples, 0); | |||
| // create sampler with replacement = true | |||
| WeightedRandomSampler m_sampler(weights, num_samples, true); | |||
| WeightedRandomSampler m_sampler(num_samples, weights, true); | |||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -89,7 +81,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestOneshotNoReplacement) { | |||
| std::vector<uint64_t> freq(total_samples, 0); | |||
| // create sampler with replacement = replacement | |||
| WeightedRandomSampler m_sampler(weights, num_samples, false); | |||
| WeightedRandomSampler m_sampler(num_samples, weights, false); | |||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -125,7 +117,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferReplacement) { | |||
| std::vector<double> weights(total_samples, std::rand() % 100); | |||
| // create sampler with replacement = replacement | |||
| WeightedRandomSampler m_sampler(weights, num_samples, true, samples_per_buffer); | |||
| WeightedRandomSampler m_sampler(num_samples, weights, true, samples_per_buffer); | |||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -161,7 +153,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestGetNextBufferNoReplacement) { | |||
| std::vector<uint64_t> freq(total_samples, 0); | |||
| // create sampler with replacement = replacement | |||
| WeightedRandomSampler m_sampler(weights, num_samples, false, samples_per_buffer); | |||
| WeightedRandomSampler m_sampler(num_samples, weights, false, samples_per_buffer); | |||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -202,7 +194,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetReplacement) { | |||
| std::vector<uint64_t> freq(total_samples, 0); | |||
| // create sampler with replacement = true | |||
| WeightedRandomSampler m_sampler(weights, num_samples, true); | |||
| WeightedRandomSampler m_sampler(num_samples, weights, true); | |||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -247,7 +239,7 @@ TEST_F(MindDataTestWeightedRandomSampler, TestResetNoReplacement) { | |||
| std::vector<uint64_t> freq(total_samples, 0); | |||
| // create sampler with replacement = true | |||
| WeightedRandomSampler m_sampler(weights, num_samples, false); | |||
| WeightedRandomSampler m_sampler(num_samples, weights, false); | |||
| DummyRandomAccessOp dummyRandomAccessOp(total_samples); | |||
| m_sampler.HandshakeRandomAccessOp(&dummyRandomAccessOp); | |||
| @@ -58,7 +58,7 @@ def test_imagefolder_numsamples(): | |||
| assert num_iter == 10 | |||
| random_sampler = ds.RandomSampler(num_samples=3, replacement=True) | |||
| data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler) | |||
| data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| @@ -67,7 +67,7 @@ def test_imagefolder_numsamples(): | |||
| assert num_iter == 3 | |||
| random_sampler = ds.RandomSampler(num_samples=3, replacement=False) | |||
| data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=10, num_parallel_workers=2, sampler=random_sampler) | |||
| data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_parallel_workers=2, sampler=random_sampler) | |||
| num_iter = 0 | |||
| for item in data1.create_dict_iterator(): | |||
| @@ -162,8 +162,8 @@ def test_voc_shardings(print_res=False): | |||
| voc_dir = "../data/dataset/testVOC2012" | |||
| def sharding_config(num_shards, shard_id, num_samples, shuffle, repeat_cnt=1): | |||
| sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle) | |||
| data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler, num_samples=num_samples) | |||
| sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=shuffle, num_samples=num_samples) | |||
| data1 = ds.VOCDataset(voc_dir, decode=True, sampler=sampler) | |||
| data1 = data1.repeat(repeat_cnt) | |||
| res = [] | |||
| for item in data1.create_dict_iterator(): # each data is a dictionary | |||
| @@ -35,18 +35,13 @@ def test_exception_01(): | |||
| def test_exception_02(): | |||
| """ | |||
| Test multiple exceptions with invalid input | |||
| Test exceptions with invalid input, and test valid input | |||
| """ | |||
| logger.info("test_exception_02") | |||
| num_samples = 0 | |||
| with pytest.raises(ValueError) as info: | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||
| assert "num_samples must be greater than 0" in str(info.value) | |||
| num_samples = -1 | |||
| with pytest.raises(ValueError) as info: | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||
| assert "num_samples must be greater than 0" in str(info.value) | |||
| assert "num_samples cannot be less than 0" in str(info.value) | |||
| num_samples = 1 | |||
| data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], num_samples=num_samples) | |||
| @@ -544,7 +544,7 @@ def test_distributed_sampler(): | |||
| def test_num_samples(): | |||
| source = [(np.array([x]),) for x in range(64)] | |||
| num_samples = 32 | |||
| ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_samples=num_samples) | |||
| ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(num_samples=num_samples)) | |||
| ds2 = ds.GeneratorDataset(source, ["data"], sampler=[i for i in range(32)], num_samples=num_samples) | |||
| ds3 = ds.GeneratorDataset(generator_1d, ["data"], num_samples=num_samples) | |||
| @@ -660,4 +660,6 @@ if __name__ == "__main__": | |||
| test_sequential_sampler() | |||
| test_distributed_sampler() | |||
| test_random_sampler() | |||
| test_num_samples() | |||
| test_num_samples_underflow() | |||
| test_schema() | |||
| @@ -28,8 +28,8 @@ def test_sequential_sampler(print_res=False): | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(num_samples, num_repeats=None): | |||
| sampler = ds.SequentialSampler() | |||
| data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler) | |||
| sampler = ds.SequentialSampler(num_samples=num_samples) | |||
| data1 = ds.ManifestDataset(manifest_file, sampler=sampler) | |||
| if num_repeats is not None: | |||
| data1 = data1.repeat(num_repeats) | |||
| res = [] | |||
| @@ -43,6 +43,7 @@ def test_sequential_sampler(print_res=False): | |||
| assert test_config(num_samples=3, num_repeats=None) == [0, 1, 2] | |||
| assert test_config(num_samples=None, num_repeats=2) == [0, 1, 2, 3, 4] * 2 | |||
| assert test_config(num_samples=0, num_repeats=2) == [0, 1, 2, 3, 4] * 2 | |||
| assert test_config(num_samples=4, num_repeats=2) == [0, 1, 2, 3] * 2 | |||
| @@ -119,8 +120,8 @@ def test_python_sampler(): | |||
| return iter([i for i in range(self.dataset_size)]) | |||
| class Sp2(ds.Sampler): | |||
| def __init__(self): | |||
| super(Sp2, self).__init__() | |||
| def __init__(self, num_samples=None): | |||
| super(Sp2, self).__init__(num_samples) | |||
| # at this stage, self.dataset_size and self.num_samples are not yet known | |||
| self.cnt = 0 | |||
| @@ -130,8 +131,8 @@ def test_python_sampler(): | |||
| def reset(self): | |||
| self.cnt = (self.cnt + 1) % self.dataset_size | |||
| def test_config(num_samples, num_repeats, sampler): | |||
| data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler) | |||
| def test_config(num_repeats, sampler): | |||
| data1 = ds.ManifestDataset(manifest_file, sampler=sampler) | |||
| if num_repeats is not None: | |||
| data1 = data1.repeat(num_repeats) | |||
| res = [] | |||
| @@ -154,8 +155,8 @@ def test_python_sampler(): | |||
| assert data[0] == (np.array(i),) | |||
| i = i - 1 | |||
| assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] | |||
| assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] | |||
| assert test_config(2, Sp1(5)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] | |||
| assert test_config(6, Sp2(2)) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] | |||
| test_generator() | |||
| sp1 = Sp1().create() | |||
| @@ -169,9 +170,8 @@ def test_subset_sampler(): | |||
| manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(num_samples, start_index, subset_size): | |||
| _ = num_samples | |||
| sampler = ds.SubsetSampler(start_index, subset_size) | |||
| def test_config(start_index, num_samples): | |||
| sampler = ds.SequentialSampler(start_index, num_samples) | |||
| d = ds.ManifestDataset(manifest_file, sampler=sampler) | |||
| res = [] | |||
| @@ -180,19 +180,15 @@ def test_subset_sampler(): | |||
| return res | |||
| with pytest.raises(RuntimeError) as info: | |||
| test_config(5, 0, 0) | |||
| assert "subset_size <= 0" in str(info.value) | |||
| assert test_config(5, 0, 1) == [0] | |||
| assert test_config(5, 0, 2) == [0, 1] | |||
| assert test_config(5, 0, 3) == [0, 1, 2] | |||
| assert test_config(5, 0, 4) == [0, 1, 2, 3] | |||
| assert test_config(5, 0, 5) == [0, 1, 2, 3, 4] | |||
| assert test_config(5, 1, 1) == [1] | |||
| assert test_config(5, 2, 3) == [2, 3, 4] | |||
| assert test_config(5, 3, 2) == [3, 4] | |||
| assert test_config(5, 4, 1) == [4] | |||
| assert test_config(0, 1) == [0] | |||
| assert test_config(0, 2) == [0, 1] | |||
| assert test_config(0, 3) == [0, 1, 2] | |||
| assert test_config(0, 4) == [0, 1, 2, 3] | |||
| assert test_config(0, 5) == [0, 1, 2, 3, 4] | |||
| assert test_config(1, 1) == [1] | |||
| assert test_config(2, 3) == [2, 3, 4] | |||
| assert test_config(3, 2) == [3, 4] | |||
| assert test_config(4, 1) == [4] | |||
| def test_sampler_chain(): | |||
| @@ -200,11 +196,11 @@ def test_sampler_chain(): | |||
| map_ = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} | |||
| def test_config(num_shards, shard_id): | |||
| sampler = ds.DistributedSampler(num_shards, shard_id, False) | |||
| sampler = ds.DistributedSampler(num_shards, shard_id, shuffle=False, num_samples=5) | |||
| child_sampler = ds.SequentialSampler() | |||
| sampler.add_child(child_sampler) | |||
| data1 = ds.ManifestDataset(manifest_file, num_samples=5, sampler=sampler) | |||
| data1 = ds.ManifestDataset(manifest_file, sampler=sampler) | |||
| res = [] | |||
| for item in data1.create_dict_iterator(): | |||
| @@ -234,6 +230,11 @@ def test_add_sampler_invalid_input(): | |||
| data1.use_sampler("sampler") | |||
| assert "not an instance of a sampler" in str(info.value) | |||
| sampler = ds.SequentialSampler() | |||
| with pytest.raises(ValueError) as info: | |||
| data2 = ds.ManifestDataset(manifest_file, sampler=sampler, num_samples=20) | |||
| assert "Conflicting arguments during sampler assignments" in str(info.value) | |||
| if __name__ == '__main__': | |||
| test_sequential_sampler(True) | |||