Commits (tag v1.0.0): change var name in CelebA from dataset_type to usage; change mode to usage in VOCDataset; address review comments; fix C++ unit test failure; fix CI test case failure; assorted CI fixes (1-7).
@@ -15,7 +15,7 @@
  */
 #include <fstream>
+#include <unordered_set>
 #include "minddata/dataset/include/datasets.h"
 #include "minddata/dataset/include/samplers.h"
 #include "minddata/dataset/include/transforms.h"
@@ -132,26 +132,28 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s
 }
 // Function to create a CelebADataset.
-std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type,
+std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage,
                                       const std::shared_ptr<SamplerObj> &sampler, bool decode,
                                       const std::set<std::string> &extensions) {
-  auto ds = std::make_shared<CelebADataset>(dataset_dir, dataset_type, sampler, decode, extensions);
+  auto ds = std::make_shared<CelebADataset>(dataset_dir, usage, sampler, decode, extensions);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
 }
 // Function to create a Cifar10Dataset.
-std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, sampler);
+std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage,
+                                        const std::shared_ptr<SamplerObj> &sampler) {
+  auto ds = std::make_shared<Cifar10Dataset>(dataset_dir, usage, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
 }
 // Function to create a Cifar100Dataset.
-std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, sampler);
+std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage,
+                                          const std::shared_ptr<SamplerObj> &sampler) {
+  auto ds = std::make_shared<Cifar100Dataset>(dataset_dir, usage, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -217,8 +219,9 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const
 #endif
 // Function to create a MnistDataset.
-std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<MnistDataset>(dataset_dir, sampler);
+std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage,
+                                    const std::shared_ptr<SamplerObj> &sampler) {
+  auto ds = std::make_shared<MnistDataset>(dataset_dir, usage, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
@@ -244,10 +247,10 @@ std::shared_ptr<TextFileDataset> TextFile(const std::vector<std::string> &datase
 #ifndef ENABLE_ANDROID
 // Function to create a VOCDataset.
-std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &mode,
+std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task, const std::string &usage,
                                 const std::map<std::string, int32_t> &class_indexing, bool decode,
                                 const std::shared_ptr<SamplerObj> &sampler) {
-  auto ds = std::make_shared<VOCDataset>(dataset_dir, task, mode, class_indexing, decode, sampler);
+  auto ds = std::make_shared<VOCDataset>(dataset_dir, task, usage, class_indexing, decode, sampler);
   // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
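The factory functions above now take the split selector as an explicit `usage` string ahead of the sampler. A minimal, hypothetical call sketch of the renamed API (directory paths and sampler arguments are placeholders, not part of this patch; `RandomSampler` is assumed from the samplers.h header included at the top of this file, and the calls assume the appropriate namespace qualification):

```cpp
// Hedged sketch only: paths are placeholders.
std::shared_ptr<Cifar10Dataset> cifar = Cifar10("/path/to/cifar10", "train", RandomSampler(false, 10));
std::shared_ptr<MnistDataset> mnist = Mnist("/path/to/mnist", "test", RandomSampler(false, 10));
std::shared_ptr<VOCDataset> voc = VOC("/path/to/voc", "Detection", "train", {}, false, RandomSampler(false, 10));
// Each factory returns nullptr when ValidateParams() rejects its arguments (e.g. an unknown usage string).
```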
@@ -727,6 +730,10 @@ bool ValidateDatasetSampler(const std::string &dataset_name, const std::shared_p
   return true;
 }

+bool ValidateStringValue(const std::string &str, const std::unordered_set<std::string> &valid_strings) {
+  return valid_strings.find(str) != valid_strings.end();
+}
+
 // Helper function to validate dataset input/output column parameter
 bool ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
                                 const std::vector<std::string> &columns) {
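The new helper is a plain membership test over an `unordered_set`, which lets the per-dataset `ValidateParams()` implementations further down fold the usage check into their boolean chain. Illustrative values only:

```cpp
// true: "train" is in the accepted set
bool ok = ValidateStringValue("train", {"all", "train", "valid", "test"});
// false: an unknown usage string is rejected
bool bad = ValidateStringValue("validation", {"all", "train", "valid", "test"});
```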
@@ -802,29 +809,14 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumDataset::Build() {
 }
 // Constructor for CelebADataset
-CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &dataset_type,
+CelebADataset::CelebADataset(const std::string &dataset_dir, const std::string &usage,
                              const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
                              const std::set<std::string> &extensions)
-    : dataset_dir_(dataset_dir),
-      dataset_type_(dataset_type),
-      sampler_(sampler),
-      decode_(decode),
-      extensions_(extensions) {}
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler), decode_(decode), extensions_(extensions) {}
 bool CelebADataset::ValidateParams() {
-  if (!ValidateDatasetDirParam("CelebADataset", dataset_dir_)) {
-    return false;
-  }
-  if (!ValidateDatasetSampler("CelebADataset", sampler_)) {
-    return false;
-  }
-  std::set<std::string> dataset_type_list = {"all", "train", "valid", "test"};
-  auto iter = dataset_type_list.find(dataset_type_);
-  if (iter == dataset_type_list.end()) {
-    MS_LOG(ERROR) << "dataset_type should be one of 'all', 'train', 'valid' or 'test'.";
-    return false;
-  }
-  return true;
+  return ValidateDatasetDirParam("CelebADataset", dataset_dir_) && ValidateDatasetSampler("CelebADataset", sampler_) &&
+         ValidateStringValue(usage_, {"all", "train", "valid", "test"});
 }
 // Function to build CelebADataset
@@ -839,17 +831,20 @@ std::vector<std::shared_ptr<DatasetOp>> CelebADataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
   node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
-                                                decode_, dataset_type_, extensions_, std::move(schema),
+                                                decode_, usage_, extensions_, std::move(schema),
                                                 std::move(sampler_->Build())));
   return node_ops;
 }
 // Constructor for Cifar10Dataset
-Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), sampler_(sampler) {}
+Cifar10Dataset::Cifar10Dataset(const std::string &dataset_dir, const std::string &usage,
+                               std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
 bool Cifar10Dataset::ValidateParams() {
-  return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) && ValidateDatasetSampler("Cifar10Dataset", sampler_);
+  return ValidateDatasetDirParam("Cifar10Dataset", dataset_dir_) &&
+         ValidateDatasetSampler("Cifar10Dataset", sampler_) &&
+         ValidateStringValue(usage_, {"train", "test", "all", ""});
 }
 // Function to build CifarOp for Cifar10
@@ -864,19 +859,21 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Dataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, num_workers_, rows_per_buffer_,
+  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
                                                dataset_dir_, connector_que_size_, std::move(schema),
                                                std::move(sampler_->Build())));
   return node_ops;
 }
 // Constructor for Cifar100Dataset
-Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), sampler_(sampler) {}
+Cifar100Dataset::Cifar100Dataset(const std::string &dataset_dir, const std::string &usage,
+                                 std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
 bool Cifar100Dataset::ValidateParams() {
   return ValidateDatasetDirParam("Cifar100Dataset", dataset_dir_) &&
-         ValidateDatasetSampler("Cifar100Dataset", sampler_);
+         ValidateDatasetSampler("Cifar100Dataset", sampler_) &&
+         ValidateStringValue(usage_, {"train", "test", "all", ""});
 }
 // Function to build CifarOp for Cifar100
@@ -893,7 +890,7 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Dataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, num_workers_, rows_per_buffer_,
+  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
                                                dataset_dir_, connector_que_size_, std::move(schema),
                                                std::move(sampler_->Build())));
   return node_ops;
@@ -1360,11 +1357,12 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestDataset::Build() {
 }
 #endif
-MnistDataset::MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler)
-    : dataset_dir_(dataset_dir), sampler_(sampler) {}
+MnistDataset::MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler)
+    : dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
 bool MnistDataset::ValidateParams() {
-  return ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
+  return ValidateStringValue(usage_, {"train", "test", "all", ""}) &&
+         ValidateDatasetDirParam("MnistDataset", dataset_dir_) && ValidateDatasetSampler("MnistDataset", sampler_);
 }
 std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
@@ -1378,8 +1376,8 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
   RETURN_EMPTY_IF_ERROR(
     schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
-  node_ops.push_back(std::make_shared<MnistOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
-                                               std::move(schema), std::move(sampler_->Build())));
+  node_ops.push_back(std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_,
+                                               connector_que_size_, std::move(schema), std::move(sampler_->Build())));
   return node_ops;
 }
@@ -1570,12 +1568,12 @@ std::vector<std::shared_ptr<DatasetOp>> TFRecordDataset::Build() {
 #ifndef ENABLE_ANDROID
 // Constructor for VOCDataset
-VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
+VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage,
                        const std::map<std::string, int32_t> &class_indexing, bool decode,
                        std::shared_ptr<SamplerObj> sampler)
     : dataset_dir_(dataset_dir),
       task_(task),
-      mode_(mode),
+      usage_(usage),
       class_index_(class_indexing),
       decode_(decode),
       sampler_(sampler) {}
@@ -1594,15 +1592,15 @@ bool VOCDataset::ValidateParams() {
       MS_LOG(ERROR) << "class_indexing is invalid in Segmentation task.";
       return false;
     }
-    Path imagesets_file = dir / "ImageSets" / "Segmentation" / mode_ + ".txt";
+    Path imagesets_file = dir / "ImageSets" / "Segmentation" / usage_ + ".txt";
     if (!imagesets_file.Exists()) {
-      MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
+      MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
       return false;
     }
   } else if (task_ == "Detection") {
-    Path imagesets_file = dir / "ImageSets" / "Main" / mode_ + ".txt";
+    Path imagesets_file = dir / "ImageSets" / "Main" / usage_ + ".txt";
     if (!imagesets_file.Exists()) {
-      MS_LOG(ERROR) << "Invalid mode: " << mode_ << ", file \"" << imagesets_file << "\" is not exists!";
+      MS_LOG(ERROR) << "Invalid mode: " << usage_ << ", file \"" << imagesets_file << "\" does not exist!";
       return false;
     }
   } else {
@@ -1641,7 +1639,7 @@ std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
   }
   std::shared_ptr<VOCOp> voc_op;
-  voc_op = std::make_shared<VOCOp>(task_type_, mode_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
+  voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
                                    connector_que_size_, decode_, std::move(schema), std::move(sampler_->Build()));
   node_ops.push_back(voc_op);
   return node_ops;
@@ -41,9 +41,9 @@ namespace dataset {
 PYBIND_REGISTER(CifarOp, 1, ([](const py::module *m) {
                   (void)py::class_<CifarOp, DatasetOp, std::shared_ptr<CifarOp>>(*m, "CifarOp")
-                    .def_static("get_num_rows", [](const std::string &dir, bool isCifar10) {
+                    .def_static("get_num_rows", [](const std::string &dir, const std::string &usage, bool isCifar10) {
                       int64_t count = 0;
-                      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, isCifar10, &count));
+                      THROW_IF_ERROR(CifarOp::CountTotalRows(dir, usage, isCifar10, &count));
                       return count;
                     });
                 }));
@@ -131,9 +131,9 @@ PYBIND_REGISTER(MindRecordOp, 1, ([](const py::module *m) {
 PYBIND_REGISTER(MnistOp, 1, ([](const py::module *m) {
                   (void)py::class_<MnistOp, DatasetOp, std::shared_ptr<MnistOp>>(*m, "MnistOp")
-                    .def_static("get_num_rows", [](const std::string &dir) {
+                    .def_static("get_num_rows", [](const std::string &dir, const std::string &usage) {
                       int64_t count = 0;
-                      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, &count));
+                      THROW_IF_ERROR(MnistOp::CountTotalRows(dir, usage, &count));
                       return count;
                     });
                 }));
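Both bindings above now forward the extra `usage` string to the static row counters. From native code, the corresponding calls look roughly as follows (a hedged sketch; the directory paths are placeholders):

```cpp
// Both counters return a Status; signatures as changed by this patch.
int64_t cifar_rows = 0;
Status rc1 = CifarOp::CountTotalRows("/path/to/cifar-10-batches-bin", "train", /*isCIFAR10=*/true, &cifar_rows);

int64_t mnist_rows = 0;
Status rc2 = MnistOp::CountTotalRows("/path/to/mnist", "all", &mnist_rows);
```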
@@ -1354,25 +1354,14 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
 Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                               std::shared_ptr<DatasetOp> *bottom) {
-  if (args["dataset_dir"].is_none()) {
-    std::string err_msg = "Error: No dataset path specified";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (args["task"].is_none()) {
-    std::string err_msg = "Error: No task specified";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (args["mode"].is_none()) {
-    std::string err_msg = "Error: No mode specified";
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
+  CHECK_FAIL_RETURN_UNEXPECTED(!args["dataset_dir"].is_none(), "Error: No dataset path specified.");
+  CHECK_FAIL_RETURN_UNEXPECTED(!args["task"].is_none(), "Error: No task specified.");
+  CHECK_FAIL_RETURN_UNEXPECTED(!args["usage"].is_none(), "Error: No usage specified.");
   std::shared_ptr<VOCOp::Builder> builder = std::make_shared<VOCOp::Builder>();
   (void)builder->SetDir(ToString(args["dataset_dir"]));
   (void)builder->SetTask(ToString(args["task"]));
-  (void)builder->SetMode(ToString(args["mode"]));
+  (void)builder->SetUsage(ToString(args["usage"]));
   for (auto arg : args) {
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
@@ -1461,6 +1450,8 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
         std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
         (void)builder->SetSampler(std::move(sampler));
+      } else if (key == "usage") {
+        (void)builder->SetUsage(ToString(value));
       }
     }
   }
@@ -1495,6 +1486,8 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
         std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
         (void)builder->SetSampler(std::move(sampler));
+      } else if (key == "usage") {
+        (void)builder->SetUsage(ToString(value));
       }
     }
   }
@@ -1608,6 +1601,8 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
         std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
         (void)builder->SetSampler(std::move(sampler));
+      } else if (key == "usage") {
+        (void)builder->SetUsage(ToString(value));
       }
     }
   }
@@ -1645,8 +1640,8 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
         (void)builder->SetDecode(ToBool(value));
       } else if (key == "extensions") {
         (void)builder->SetExtensions(ToStringSet(value));
-      } else if (key == "dataset_type") {
-        (void)builder->SetDatasetType(ToString(value));
+      } else if (key == "usage") {
+        (void)builder->SetUsage(ToString(value));
       }
     }
   }
@@ -36,7 +36,7 @@ CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr)
 Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   MS_LOG(DEBUG) << "Celeba dataset directory is " << builder_dir_.c_str() << ".";
-  MS_LOG(DEBUG) << "Celeba dataset type is " << builder_dataset_type_.c_str() << ".";
+  MS_LOG(DEBUG) << "Celeba dataset type is " << builder_usage_.c_str() << ".";
   RETURN_IF_NOT_OK(SanityCheck());
   if (builder_sampler_ == nullptr) {
     const int64_t num_samples = 0;
@@ -51,8 +51,8 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
   *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                   builder_op_connector_size_, builder_decode_, builder_dataset_type_,
-                                   builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
+                                   builder_op_connector_size_, builder_decode_, builder_usage_, builder_extensions_,
+                                   std::move(builder_schema_), std::move(builder_sampler_));
   if (*op == nullptr) {
     return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CelebAOp is null");
   }
@@ -69,7 +69,7 @@ Status CelebAOp::Builder::SanityCheck() {
 }
 CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
-                   bool decode, const std::string &dataset_type, const std::set<std::string> &exts,
+                   bool decode, const std::string &usage, const std::set<std::string> &exts,
                    std::unique_ptr<DataSchema> schema, std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_workers, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
@@ -78,7 +78,7 @@ CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::stri
       extensions_(exts),
       data_schema_(std::move(schema)),
       num_rows_in_attr_file_(0),
-      dataset_type_(dataset_type) {
+      usage_(usage) {
   attr_info_queue_ = std::make_unique<Queue<std::vector<std::string>>>(queue_size);
   io_block_queues_.Init(num_workers_, queue_size);
 }
@@ -135,7 +135,7 @@ Status CelebAOp::ParseAttrFile() {
   std::vector<std::string> image_infos;
   image_infos.reserve(oc_queue_size_);
   while (getline(attr_file, image_info)) {
-    if ((image_info.empty()) || (dataset_type_ != "all" && !CheckDatasetTypeValid())) {
+    if ((image_info.empty()) || (usage_ != "all" && !CheckDatasetTypeValid())) {
       continue;
     }
     image_infos.push_back(image_info);
@@ -179,11 +179,11 @@ bool CelebAOp::CheckDatasetTypeValid() {
     return false;
   }
   // train:0, valid=1, test=2
-  if (dataset_type_ == "train" && (type == 0)) {
+  if (usage_ == "train" && (type == 0)) {
     return true;
-  } else if (dataset_type_ == "valid" && (type == 1)) {
+  } else if (usage_ == "valid" && (type == 1)) {
     return true;
-  } else if (dataset_type_ == "test" && (type == 2)) {
+  } else if (usage_ == "test" && (type == 2)) {
     return true;
   }
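For context on the branch above: CelebA's partition list tags every image with 0 (train), 1 (valid) or 2 (test), and ParseAttrFile() skips this check entirely when `usage_` is "all". A hypothetical standalone helper restating the same mapping, for illustration only:

```cpp
// Not part of the patch: restates CheckDatasetTypeValid()'s usage-to-partition mapping.
bool MatchesUsage(const std::string &usage, int partition_type) {
  // partition file convention: train = 0, valid = 1, test = 2
  return (usage == "train" && partition_type == 0) || (usage == "valid" && partition_type == 1) ||
         (usage == "test" && partition_type == 2);
}
```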
@@ -109,10 +109,10 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
     }
     // Setter method
-    // @param const std::string dataset_type: type to be read
+    // @param const std::string usage: type to be read
     // @return Builder setter method returns reference to the builder.
-    Builder &SetDatasetType(const std::string &dataset_type) {
-      builder_dataset_type_ = dataset_type;
+    Builder &SetUsage(const std::string &usage) {
+      builder_usage_ = usage;
       return *this;
     }
     // Check validity of input args
@@ -133,7 +133,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
     std::set<std::string> builder_extensions_;
     std::shared_ptr<Sampler> builder_sampler_;
     std::unique_ptr<DataSchema> builder_schema_;
-    std::string builder_dataset_type_;
+    std::string builder_usage_;
   };
   // Constructor
@@ -143,12 +143,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @param int32_t queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
   CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
-           const std::string &dataset_type, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
+           const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
            std::shared_ptr<Sampler> sampler);
   ~CelebAOp() override = default;
-  // Main Loop of CelebaOp
+  // Main Loop of CelebAOp
   // Master thread: Fill IOBlockQueue, then goes to sleep
   // Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
   // @return Status - The error code return
@@ -177,7 +177,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // Op name getter
   // @return Name of the current Op
-  std::string Name() const { return "CelebAOp"; }
+  std::string Name() const override { return "CelebAOp"; }
 private:
   // Called first when function is called
@@ -232,7 +232,7 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
   WaitPost wp_;
   std::vector<std::pair<std::string, std::vector<int32_t>>> image_labels_vec_;
-  std::string dataset_type_;
+  std::string usage_;
   std::ifstream partition_file_;
 };
 }  // namespace dataset
@@ -18,15 +18,16 @@
 #include <algorithm>
 #include <fstream>
 #include <iomanip>
+#include <set>
 #include <utility>
-#include "utils/ms_utils.h"
 #include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/core/tensor_shape.h"
 #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "minddata/dataset/engine/db_connector.h"
 #include "minddata/dataset/engine/execution_tree.h"
 #include "minddata/dataset/engine/opt/pass.h"
+#include "utils/ms_utils.h"
 namespace mindspore {
 namespace dataset {
@@ -36,7 +37,7 @@ constexpr uint32_t kCifarImageChannel = 3;
 constexpr uint32_t kCifarBlockImageNum = 5;
 constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCifarImageChannel;
-CifarOp::Builder::Builder() : sampler_(nullptr) {
+CifarOp::Builder::Builder() : sampler_(nullptr), usage_("") {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   num_workers_ = cfg->num_parallel_workers();
   rows_per_buffer_ = cfg->rows_per_buffer();
@@ -65,23 +66,27 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
       ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
   }
-  *ptr = std::make_shared<CifarOp>(cifar_type_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
+  *ptr = std::make_shared<CifarOp>(cifar_type_, usage_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
                                    std::move(schema_), std::move(sampler_));
   return Status::OK();
 }
 Status CifarOp::Builder::SanityCheck() {
+  const std::set<std::string> valid = {"test", "train", "all", ""};
   Path dir(dir_);
   std::string err_msg;
   err_msg += dir.IsDirectory() == false ? "Cifar path is invalid or not set\n" : "";
   err_msg += num_workers_ <= 0 ? "Num of parallel workers is negative or 0\n" : "";
+  err_msg += valid.find(usage_) == valid.end() ? "usage needs to be 'train','test' or 'all'\n" : "";
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg);
 }
-CifarOp::CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir,
-                 int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
+CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
+                 const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
+                 std::shared_ptr<Sampler> sampler)
     : ParallelOp(num_works, queue_size, std::move(sampler)),
       cifar_type_(type),
+      usage_(usage),
       rows_per_buffer_(rows_per_buf),
       folder_path_(file_dir),
       data_schema_(std::move(data_schema)),
@@ -258,21 +263,32 @@ Status CifarOp::ReadCifarBlockDataAsync() {
 }
 Status CifarOp::ReadCifar10BlockData() {
+  // CIFAR 10 has 6 bin files. data_batch_1.bin ... data_batch_5.bin and 1 test_batch.bin file
+  // each of the file has exactly 10K images and labels and size is 30,730 KB
+  // each image has the dimension of 32 x 32 x 3 = 3072 plus 1 label (label has 10 classes) so each row has 3073 bytes
   constexpr uint32_t num_cifar10_records = 10000;
   uint32_t block_size = (kCifarImageSize + 1) * kCifarBlockImageNum;  // about 2M
   std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
   for (auto &file : cifar_files_) {
-    std::ifstream in(file, std::ios::binary);
-    if (!in.is_open()) {
-      std::string err_msg = file + " can not be opened.";
-      RETURN_STATUS_UNEXPECTED(err_msg);
+    // check the validity of the file path
+    Path file_path(file);
+    CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+    std::string file_name = file_path.Basename();
+    if (usage_ == "train") {
+      if (file_name.find("data_batch") == std::string::npos) continue;
+    } else if (usage_ == "test") {
+      if (file_name.find("test_batch") == std::string::npos) continue;
+    } else {  // get all the files that contain the word batch, aka any cifar 100 files
+      if (file_name.find("batch") == std::string::npos) continue;
     }
+    std::ifstream in(file, std::ios::binary);
+    CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
     for (uint32_t index = 0; index < num_cifar10_records / kCifarBlockImageNum; ++index) {
       (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
-      if (in.fail()) {
-        RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
-      }
+      CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
       (void)cifar_raw_data_block_->EmplaceBack(image_data);
     }
     in.close();
@@ -283,15 +299,21 @@ Status CifarOp::ReadCifar10BlockData() {
 }
 Status CifarOp::ReadCifar100BlockData() {
+  // CIFAR 100 has 2 bin files. train.bin (50K imgs) 153,700KB and test.bin (30,740KB) (10K imgs)
+  // each img has two labels. Each row then is 32 * 32 * 3 + 2 = 3,074 Bytes
   uint32_t num_cifar100_records = 0;  // test:10000, train:50000
   uint32_t block_size = (kCifarImageSize + 2) * kCifarBlockImageNum;  // about 2M
   std::vector<unsigned char> image_data(block_size * sizeof(unsigned char), 0);
   for (auto &file : cifar_files_) {
-    int pos = file.find_last_of('/');
-    if (pos == std::string::npos) {
-      RETURN_STATUS_UNEXPECTED("Invalid cifar100 file path");
-    }
-    std::string file_name(file.substr(pos + 1));
+    // check the validity of the file path
+    Path file_path(file);
+    CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+    std::string file_name = file_path.Basename();
+    // if usage is train/test, get only these 2 files
+    if (usage_ == "train" && file_name.find("train") == std::string::npos) continue;
+    if (usage_ == "test" && file_name.find("test") == std::string::npos) continue;
     if (file_name.find("test") != std::string::npos) {
       num_cifar100_records = 10000;
     } else if (file_name.find("train") != std::string::npos) {
@@ -301,15 +323,11 @@ Status CifarOp::ReadCifar100BlockData() {
     }
     std::ifstream in(file, std::ios::binary);
-    if (!in.is_open()) {
-      RETURN_STATUS_UNEXPECTED(file + " can not be opened.");
-    }
+    CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
     for (uint32_t index = 0; index < num_cifar100_records / kCifarBlockImageNum; index++) {
       (void)in.read(reinterpret_cast<char *>(&(image_data[0])), block_size * sizeof(unsigned char));
-      if (in.fail()) {
-        RETURN_STATUS_UNEXPECTED("Fail to read cifar file" + file);
-      }
+      CHECK_FAIL_RETURN_UNEXPECTED(!in.fail(), "Fail to read cifar file" + file);
       (void)cifar_raw_data_block_->EmplaceBack(image_data);
     }
     in.close();
@@ -319,26 +337,20 @@
 }
 Status CifarOp::GetCifarFiles() {
-  // Initialize queue to hold the file names
   const std::string kExtension = ".bin";
-  Path dataset_directory(folder_path_);
-  auto dirIt = Path::DirIterator::OpenDirectory(&dataset_directory);
+  Path dir_path(folder_path_);
+  auto dirIt = Path::DirIterator::OpenDirectory(&dir_path);
   if (dirIt) {
     while (dirIt->hasNext()) {
       Path file = dirIt->next();
-      std::string filename = file.toString();
-      if (filename.find(kExtension) != std::string::npos) {
-        cifar_files_.push_back(filename);
-        MS_LOG(INFO) << "Cifar operator found file at " << filename << ".";
+      if (file.Extension() == kExtension) {
+        cifar_files_.push_back(file.toString());
       }
     }
   } else {
-    std::string err_msg = "Unable to open directory " + dataset_directory.toString();
-    RETURN_STATUS_UNEXPECTED(err_msg);
-  }
-  if (cifar_files_.size() == 0) {
-    RETURN_STATUS_UNEXPECTED("No .bin files found under " + folder_path_);
+    RETURN_STATUS_UNEXPECTED("Unable to open directory " + dir_path.toString());
   }
+  CHECK_FAIL_RETURN_UNEXPECTED(!cifar_files_.empty(), "No .bin files found under " + folder_path_);
   std::sort(cifar_files_.begin(), cifar_files_.end());
   return Status::OK();
 }
@@ -378,9 +390,8 @@ Status CifarOp::ParseCifarData() {
   num_rows_ = cifar_image_label_pairs_.size();
   if (num_rows_ == 0) {
     std::string api = cifar_type_ == kCifar10 ? "Cifar10Dataset" : "Cifar100Dataset";
-    std::string err_msg = "There is no valid data matching the dataset API " + api +
-                          ".Please check file path or dataset API validation first.";
-    RETURN_STATUS_UNEXPECTED(err_msg);
+    RETURN_STATUS_UNEXPECTED("There is no valid data matching the dataset API " + api +
+                             ".Please check file path or dataset API validation first.");
   }
   cifar_raw_data_block_->Reset();
   return Status::OK();
@@ -403,46 +414,51 @@ Status CifarOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co
   return Status::OK();
 }
-Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count) {
+Status CifarOp::CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count) {
   // the logic of counting the number of samples is copied from ReadCifar100Block() and ReadCifar10Block()
   std::shared_ptr<CifarOp> op;
   *count = 0;
-  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).Build(&op));
+  RETURN_IF_NOT_OK(Builder().SetCifarDir(dir).SetCifarType(isCIFAR10).SetUsage(usage).Build(&op));
   RETURN_IF_NOT_OK(op->GetCifarFiles());
   if (op->cifar_type_ == kCifar10) {
     constexpr int64_t num_cifar10_records = 10000;
     for (auto &file : op->cifar_files_) {
-      std::ifstream in(file, std::ios::binary);
-      if (!in.is_open()) {
-        std::string err_msg = file + " can not be opened.";
-        RETURN_STATUS_UNEXPECTED(err_msg);
+      Path file_path(file);
+      CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+      std::string file_name = file_path.Basename();
+      if (op->usage_ == "train") {
+        if (file_name.find("data_batch") == std::string::npos) continue;
+      } else if (op->usage_ == "test") {
+        if (file_name.find("test_batch") == std::string::npos) continue;
+      } else {  // get all the files that contain the word batch, aka any cifar 100 files
+        if (file_name.find("batch") == std::string::npos) continue;
      }
+      std::ifstream in(file, std::ios::binary);
+      CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
       *count = *count + num_cifar10_records;
     }
     return Status::OK();
   } else {
     int64_t num_cifar100_records = 0;
     for (auto &file : op->cifar_files_) {
-      size_t pos = file.find_last_of('/');
-      if (pos == std::string::npos) {
-        std::string err_msg = "Invalid cifar100 file path";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
-      std::string file_name;
-      if (file.size() > 0)
-        file_name = file.substr(pos + 1);
-      else
-        RETURN_STATUS_UNEXPECTED("Invalid string length!");
+      Path file_path(file);
+      std::string file_name = file_path.Basename();
+      CHECK_FAIL_RETURN_UNEXPECTED(file_path.Exists() && !file_path.IsDirectory(), "invalid file:" + file);
+      if (op->usage_ == "train" && file_path.Basename().find("train") == std::string::npos) continue;
+      if (op->usage_ == "test" && file_path.Basename().find("test") == std::string::npos) continue;
       if (file_name.find("test") != std::string::npos) {
-        num_cifar100_records = 10000;
+        num_cifar100_records += 10000;
       } else if (file_name.find("train") != std::string::npos) {
-        num_cifar100_records = 50000;
+        num_cifar100_records += 50000;
       }
       std::ifstream in(file, std::ios::binary);
-      if (!in.is_open()) {
-        std::string err_msg = file + " can not be opened.";
-        RETURN_STATUS_UNEXPECTED(err_msg);
-      }
+      CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), file + " can not be opened.");
     }
     *count = num_cifar100_records;
     return Status::OK();
@@ -83,15 +83,23 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
     // Setter method
     // @param const std::string & dir
-    // @return
+    // @return Builder setter method returns reference to the builder.
     Builder &SetCifarDir(const std::string &dir) {
       dir_ = dir;
       return *this;
     }

+    // Setter method
+    // @param const std::string &usage
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetUsage(const std::string &usage) {
+      usage_ = usage;
+      return *this;
+    }
+
     // Setter method
     // @param const std::string & dir
-    // @return
+    // @return Builder setter method returns reference to the builder.
     Builder &SetCifarType(const bool cifar10) {
       if (cifar10) {
         cifar_type_ = kCifar10;
@@ -112,6 +120,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
    private:
     std::string dir_;
+    std::string usage_;
     int32_t num_workers_;
     int32_t rows_per_buffer_;
     int32_t op_connect_size_;
@@ -122,13 +131,15 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // Constructor
   // @param CifarType type - Cifar10 or Cifar100
+  // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
   // @param uint32_t numWorks - Num of workers reading images in parallel
   // @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer
   // @param std::string - dir directory of cifar dataset
   // @param uint32_t - queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
-  CifarOp(CifarType type, int32_t num_works, int32_t rows_per_buf, const std::string &file_dir, int32_t queue_size,
-          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler);
+  CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
+          const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
+          std::shared_ptr<Sampler> sampler);
   // Destructor.
   ~CifarOp() = default;
@@ -153,7 +164,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @param isCIFAR10 true if CIFAR10 and false if CIFAR100
   // @param count output arg that will hold the actual dataset size
   // @return
-  static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
+  static Status CountTotalRows(const std::string &dir, const std::string &usage, bool isCIFAR10, int64_t *count);
   /// \brief Base-class override for NodePass visitor acceptor
   /// \param[in] p Pointer to the NodePass to be accepted
@@ -224,7 +235,7 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   std::unique_ptr<DataSchema> data_schema_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
+  const std::string usage_;  // can only be either "train" or "test"
   WaitPost wp_;
   QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
   std::unique_ptr<Queue<std::vector<unsigned char>>> cifar_raw_data_block_;
@@ -17,6 +17,7 @@
 #include <fstream>
 #include <iomanip>
+#include <set>
 #include "utils/ms_utils.h"
 #include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/core/tensor_shape.h"
@@ -32,7 +33,7 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
 const int32_t kMnistImageRows = 28;
 const int32_t kMnistImageCols = 28;
-MnistOp::Builder::Builder() : builder_sampler_(nullptr) {
+MnistOp::Builder::Builder() : builder_sampler_(nullptr), builder_usage_("") {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
   builder_rows_per_buffer_ = cfg->rows_per_buffer();
| @@ -52,22 +53,25 @@ Status MnistOp::Builder::Build(std::shared_ptr<MnistOp> *ptr) { | |||||
| TensorShape scalar = TensorShape::CreateScalar(); | TensorShape scalar = TensorShape::CreateScalar(); | ||||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | RETURN_IF_NOT_OK(builder_schema_->AddColumn( | ||||
| ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar))); | ||||
| *ptr = std::make_shared<MnistOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_, | |||||
| *ptr = std::make_shared<MnistOp>(builder_usage_, builder_num_workers_, builder_rows_per_buffer_, builder_dir_, | |||||
| builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); | builder_op_connector_size_, std::move(builder_schema_), std::move(builder_sampler_)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status MnistOp::Builder::SanityCheck() { | Status MnistOp::Builder::SanityCheck() { | ||||
| const std::set<std::string> valid = {"test", "train", "all", ""}; | |||||
| Path dir(builder_dir_); | Path dir(builder_dir_); | ||||
| std::string err_msg; | std::string err_msg; | ||||
| err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; | err_msg += dir.IsDirectory() == false ? "MNIST path is invalid or not set\n" : ""; | ||||
| err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; | err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is set to 0 or negative\n" : ""; | ||||
| err_msg += valid.find(builder_usage_) == valid.end() ? "usage needs to be 'train', 'test' or 'all'\n" : ""; | |||||
| return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); | ||||
| } | } | ||||
| MnistOp::MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||||
| MnistOp::MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | |||||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler) | |||||
| : ParallelOp(num_workers, queue_size, std::move(sampler)), | : ParallelOp(num_workers, queue_size, std::move(sampler)), | ||||
| usage_(usage), | |||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| row_cnt_(0), | row_cnt_(0), | ||||
| folder_path_(folder_path), | folder_path_(folder_path), | ||||
| @@ -226,9 +230,7 @@ Status MnistOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) co | |||||
| Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) { | Status MnistOp::ReadFromReader(std::ifstream *reader, uint32_t *result) { | ||||
| uint32_t res = 0; | uint32_t res = 0; | ||||
| reader->read(reinterpret_cast<char *>(&res), 4); | reader->read(reinterpret_cast<char *>(&res), 4); | ||||
| if (reader->fail()) { | |||||
| RETURN_STATUS_UNEXPECTED("Failed to read 4 bytes from file"); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(!reader->fail(), "Failed to read 4 bytes from file"); | |||||
| *result = SwapEndian(res); | *result = SwapEndian(res); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -239,15 +241,12 @@ uint32_t MnistOp::SwapEndian(uint32_t val) const { | |||||
| } | } | ||||
| Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) { | Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_reader, uint32_t *num_images) { | ||||
| if (image_reader->is_open() == false) { | |||||
| RETURN_STATUS_UNEXPECTED("Cannot open mnist image file: " + file_name); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(image_reader->is_open(), "Cannot open mnist image file: " + file_name); | |||||
| int64_t image_len = image_reader->seekg(0, std::ios::end).tellg(); | int64_t image_len = image_reader->seekg(0, std::ios::end).tellg(); | ||||
| (void)image_reader->seekg(0, std::ios::beg); | (void)image_reader->seekg(0, std::ios::beg); | ||||
| // The first 16 bytes of the image file are type, number, row and column | // The first 16 bytes of the image file are type, number, row and column | ||||
| if (image_len < 16) { | |||||
| RETURN_STATUS_UNEXPECTED("Mnist file is corrupted."); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(image_len >= 16, "Mnist file is corrupted."); | |||||
| uint32_t magic_number; | uint32_t magic_number; | ||||
| RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number)); | RETURN_IF_NOT_OK(ReadFromReader(image_reader, &magic_number)); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber, | CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistImageFileMagicNumber, | ||||
| @@ -260,35 +259,25 @@ Status MnistOp::CheckImage(const std::string &file_name, std::ifstream *image_re | |||||
| uint32_t cols; | uint32_t cols; | ||||
| RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols)); | RETURN_IF_NOT_OK(ReadFromReader(image_reader, &cols)); | ||||
| // The image size of the Mnist dataset is fixed at [28,28] | // The image size of the Mnist dataset is fixed at [28,28] | ||||
| if ((rows != kMnistImageRows) || (cols != kMnistImageCols)) { | |||||
| RETURN_STATUS_UNEXPECTED("Wrong shape of image."); | |||||
| } | |||||
| if ((image_len - 16) != num_items * rows * cols) { | |||||
| RETURN_STATUS_UNEXPECTED("Wrong number of image."); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED((rows == kMnistImageRows) && (cols == kMnistImageCols), "Wrong shape of image."); | |||||
| CHECK_FAIL_RETURN_UNEXPECTED((image_len - 16) == num_items * rows * cols, "Wrong number of images."); | |||||
| *num_images = num_items; | *num_images = num_items; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) { | Status MnistOp::CheckLabel(const std::string &file_name, std::ifstream *label_reader, uint32_t *num_labels) { | ||||
| if (label_reader->is_open() == false) { | |||||
| RETURN_STATUS_UNEXPECTED("Cannot open mnist label file: " + file_name); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(label_reader->is_open(), "Cannot open mnist label file: " + file_name); | |||||
| int64_t label_len = label_reader->seekg(0, std::ios::end).tellg(); | int64_t label_len = label_reader->seekg(0, std::ios::end).tellg(); | ||||
| (void)label_reader->seekg(0, std::ios::beg); | (void)label_reader->seekg(0, std::ios::beg); | ||||
| // The first 8 bytes of the image file are type and number | // The first 8 bytes of the image file are type and number | ||||
| if (label_len < 8) { | |||||
| RETURN_STATUS_UNEXPECTED("Mnist file is corrupted."); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(label_len >= 8, "Mnist file is corrupted."); | |||||
| uint32_t magic_number; | uint32_t magic_number; | ||||
| RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number)); | RETURN_IF_NOT_OK(ReadFromReader(label_reader, &magic_number)); | ||||
| CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber, | CHECK_FAIL_RETURN_UNEXPECTED(magic_number == kMnistLabelFileMagicNumber, | ||||
| "This is not the mnist label file: " + file_name); | "This is not the mnist label file: " + file_name); | ||||
| uint32_t num_items; | uint32_t num_items; | ||||
| RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items)); | RETURN_IF_NOT_OK(ReadFromReader(label_reader, &num_items)); | ||||
| if ((label_len - 8) != num_items) { | |||||
| RETURN_STATUS_UNEXPECTED("Wrong number of labels!"); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED((label_len - 8) == num_items, "Wrong number of labels!"); | |||||
| *num_labels = num_items; | *num_labels = num_items; | ||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -330,6 +319,9 @@ Status MnistOp::ReadImageAndLabel(std::ifstream *image_reader, std::ifstream *la | |||||
| } | } | ||||
| Status MnistOp::ParseMnistData() { | Status MnistOp::ParseMnistData() { | ||||
| // MNIST contains 4 files: idx3 files are images, idx1 files are labels | |||||
| // training files contain 60K examples and testing files contain 10K examples | |||||
| // t10k-images-idx3-ubyte t10k-labels-idx1-ubyte train-images-idx3-ubyte train-labels-idx1-ubyte | |||||
| for (size_t i = 0; i < image_names_.size(); ++i) { | for (size_t i = 0; i < image_names_.size(); ++i) { | ||||
| std::ifstream image_reader, label_reader; | std::ifstream image_reader, label_reader; | ||||
| image_reader.open(image_names_[i], std::ios::binary); | image_reader.open(image_names_[i], std::ios::binary); | ||||
| @@ -354,18 +346,22 @@ Status MnistOp::ParseMnistData() { | |||||
| Status MnistOp::WalkAllFiles() { | Status MnistOp::WalkAllFiles() { | ||||
| const std::string kImageExtension = "idx3-ubyte"; | const std::string kImageExtension = "idx3-ubyte"; | ||||
| const std::string kLabelExtension = "idx1-ubyte"; | const std::string kLabelExtension = "idx1-ubyte"; | ||||
| const std::string train_prefix = "train"; | |||||
| const std::string test_prefix = "t10k"; | |||||
| Path dir(folder_path_); | Path dir(folder_path_); | ||||
| auto dir_it = Path::DirIterator::OpenDirectory(&dir); | auto dir_it = Path::DirIterator::OpenDirectory(&dir); | ||||
| std::string prefix; // empty string, used to match usage = "" (default) or usage == "all" | |||||
| if (usage_ == "train" || usage_ == "test") prefix = (usage_ == "test" ? test_prefix : train_prefix); | |||||
| if (dir_it != nullptr) { | if (dir_it != nullptr) { | ||||
| while (dir_it->hasNext()) { | while (dir_it->hasNext()) { | ||||
| Path file = dir_it->next(); | Path file = dir_it->next(); | ||||
| std::string filename = file.toString(); | |||||
| if (filename.find(kImageExtension) != std::string::npos) { | |||||
| image_names_.push_back(filename); | |||||
| std::string filename = file.Basename(); | |||||
| if (filename.find(prefix + "-images-" + kImageExtension) != std::string::npos) { | |||||
| image_names_.push_back(file.toString()); | |||||
| MS_LOG(INFO) << "Mnist operator found image file at " << filename << "."; | MS_LOG(INFO) << "Mnist operator found image file at " << filename << "."; | ||||
| } else if (filename.find(kLabelExtension) != std::string::npos) { | |||||
| label_names_.push_back(filename); | |||||
| } else if (filename.find(prefix + "-labels-" + kLabelExtension) != std::string::npos) { | |||||
| label_names_.push_back(file.toString()); | |||||
| MS_LOG(INFO) << "Mnist Operator found label file at " << filename << "."; | MS_LOG(INFO) << "Mnist Operator found label file at " << filename << "."; | ||||
| } | } | ||||
| } | } | ||||
| @@ -376,9 +372,7 @@ Status MnistOp::WalkAllFiles() { | |||||
| std::sort(image_names_.begin(), image_names_.end()); | std::sort(image_names_.begin(), image_names_.end()); | ||||
| std::sort(label_names_.begin(), label_names_.end()); | std::sort(label_names_.begin(), label_names_.end()); | ||||
| if (image_names_.size() != label_names_.size()) { | |||||
| RETURN_STATUS_UNEXPECTED("num of images does not equal to num of labels"); | |||||
| } | |||||
| CHECK_FAIL_RETURN_UNEXPECTED(image_names_.size() == label_names_.size(), "num of idx3 files != num of idx1 files"); | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| @@ -397,11 +391,11 @@ Status MnistOp::LaunchThreadsAndInitOp() { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) { | |||||
| Status MnistOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) { | |||||
| // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader() | // the logic of counting the number of samples is copied from ParseMnistData() and uses CheckReader() | ||||
| std::shared_ptr<MnistOp> op; | std::shared_ptr<MnistOp> op; | ||||
| *count = 0; | *count = 0; | ||||
| RETURN_IF_NOT_OK(Builder().SetDir(dir).Build(&op)); | |||||
| RETURN_IF_NOT_OK(Builder().SetDir(dir).SetUsage(usage).Build(&op)); | |||||
| RETURN_IF_NOT_OK(op->WalkAllFiles()); | RETURN_IF_NOT_OK(op->WalkAllFiles()); | ||||
| @@ -47,8 +47,6 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| class Builder { | class Builder { | ||||
| public: | public: | ||||
| // Constructor for Builder class of MnistOp | // Constructor for Builder class of MnistOp | ||||
| // @param uint32_t numWrks - number of parallel workers | |||||
| // @param dir - directory folder got ImageNetFolder | |||||
| Builder(); | Builder(); | ||||
| // Destructor. | // Destructor. | ||||
| @@ -87,13 +85,20 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Setter method | // Setter method | ||||
| // @param const std::string & dir | |||||
| // @param const std::string &dir | |||||
| // @return | // @return | ||||
| Builder &SetDir(const std::string &dir) { | Builder &SetDir(const std::string &dir) { | ||||
| builder_dir_ = dir; | builder_dir_ = dir; | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| // Setter method | |||||
| // @param const std::string &usage | |||||
| // @return | |||||
| Builder &SetUsage(const std::string &usage) { | |||||
| builder_usage_ = usage; | |||||
| return *this; | |||||
| } | |||||
| // Check validity of input args | // Check validity of input args | ||||
| // @return - The error code return | // @return - The error code return | ||||
| Status SanityCheck(); | Status SanityCheck(); | ||||
| @@ -105,6 +110,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| private: | private: | ||||
| std::string builder_dir_; | std::string builder_dir_; | ||||
| std::string builder_usage_; | |||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| @@ -113,14 +119,15 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| }; | }; | ||||
| // Constructor | // Constructor | ||||
| // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all' | |||||
| // @param int32_t num_workers - number of workers reading images in parallel | // @param int32_t num_workers - number of workers reading images in parallel | ||||
| // @param int32_t rows_per_buffer - number of images (rows) in each buffer | // @param int32_t rows_per_buffer - number of images (rows) in each buffer | ||||
| // @param std::string folder_path - directory of the MNIST dataset | // @param std::string folder_path - directory of the MNIST dataset | ||||
| // @param int32_t queue_size - connector queue size | // @param int32_t queue_size - connector queue size | ||||
| // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset | // @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset | ||||
| // @param std::shared_ptr<Sampler> sampler - sampler tells MnistOp what to read | // @param std::shared_ptr<Sampler> sampler - sampler tells MnistOp what to read | ||||
| MnistOp(int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, int32_t queue_size, | |||||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| MnistOp(const std::string &usage, int32_t num_workers, int32_t rows_per_buffer, std::string folder_path, | |||||
| int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler); | |||||
| // Destructor. | // Destructor. | ||||
| ~MnistOp() = default; | ~MnistOp() = default; | ||||
| @@ -150,7 +157,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| // @param dir path to the MNIST directory | // @param dir path to the MNIST directory | ||||
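| // @param usage Usage of this dataset, can be "train", "test" or "all" | |||||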
| // @param count output arg that will hold the minimum of the actual dataset size and numSamples | // @param count output arg that will hold the minimum of the actual dataset size and numSamples | ||||
| // @return | // @return | ||||
| static Status CountTotalRows(const std::string &dir, int64_t *count); | |||||
| static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); | |||||
| /// \brief Base-class override for NodePass visitor acceptor | /// \brief Base-class override for NodePass visitor acceptor | ||||
| /// \param[in] p Pointer to the NodePass to be accepted | /// \param[in] p Pointer to the NodePass to be accepted | ||||
| @@ -241,6 +248,7 @@ class MnistOp : public ParallelOp, public RandomAccessOp { | |||||
| WaitPost wp_; | WaitPost wp_; | ||||
| std::string folder_path_; // directory of image folder | std::string folder_path_; // directory of image folder | ||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| const std::string usage_;  // usage of this dataset, can be "train", "test" or "all" | |||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| std::vector<MnistLabelPair> image_label_pairs_; | std::vector<MnistLabelPair> image_label_pairs_; | ||||
| std::vector<std::string> image_names_; | std::vector<std::string> image_names_; | ||||
| @@ -18,14 +18,15 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include <iomanip> | #include <iomanip> | ||||
| #include "./tinyxml2.h" | #include "./tinyxml2.h" | ||||
| #include "utils/ms_utils.h" | |||||
| #include "minddata/dataset/core/config_manager.h" | #include "minddata/dataset/core/config_manager.h" | ||||
| #include "minddata/dataset/core/tensor_shape.h" | #include "minddata/dataset/core/tensor_shape.h" | ||||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | ||||
| #include "minddata/dataset/engine/db_connector.h" | #include "minddata/dataset/engine/db_connector.h" | ||||
| #include "minddata/dataset/engine/execution_tree.h" | #include "minddata/dataset/engine/execution_tree.h" | ||||
| #include "minddata/dataset/engine/opt/pass.h" | #include "minddata/dataset/engine/opt/pass.h" | ||||
| #include "utils/ms_utils.h" | |||||
| using tinyxml2::XMLDocument; | using tinyxml2::XMLDocument; | ||||
| using tinyxml2::XMLElement; | using tinyxml2::XMLElement; | ||||
| @@ -81,7 +82,7 @@ Status VOCOp::Builder::Build(std::shared_ptr<VOCOp> *ptr) { | |||||
| RETURN_IF_NOT_OK(builder_schema_->AddColumn( | RETURN_IF_NOT_OK(builder_schema_->AddColumn( | ||||
| ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1))); | ||||
| } | } | ||||
| *ptr = std::make_shared<VOCOp>(builder_task_type_, builder_task_mode_, builder_dir_, builder_labels_to_read_, | |||||
| *ptr = std::make_shared<VOCOp>(builder_task_type_, builder_usage_, builder_dir_, builder_labels_to_read_, | |||||
| builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, | builder_num_workers_, builder_rows_per_buffer_, builder_op_connector_size_, | ||||
| builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); | builder_decode_, std::move(builder_schema_), std::move(builder_sampler_)); | ||||
| return Status::OK(); | return Status::OK(); | ||||
| @@ -103,7 +104,7 @@ VOCOp::VOCOp(const TaskType &task_type, const std::string &task_mode, const std: | |||||
| row_cnt_(0), | row_cnt_(0), | ||||
| buf_cnt_(0), | buf_cnt_(0), | ||||
| task_type_(task_type), | task_type_(task_type), | ||||
| task_mode_(task_mode), | |||||
| usage_(task_mode), | |||||
| folder_path_(folder_path), | folder_path_(folder_path), | ||||
| class_index_(class_index), | class_index_(class_index), | ||||
| rows_per_buffer_(rows_per_buffer), | rows_per_buffer_(rows_per_buffer), | ||||
| @@ -251,10 +252,9 @@ Status VOCOp::WorkerEntry(int32_t worker_id) { | |||||
| Status VOCOp::ParseImageIds() { | Status VOCOp::ParseImageIds() { | ||||
| std::string image_sets_file; | std::string image_sets_file; | ||||
| if (task_type_ == TaskType::Segmentation) { | if (task_type_ == TaskType::Segmentation) { | ||||
| image_sets_file = | |||||
| folder_path_ + std::string(kImageSetsSegmentation) + task_mode_ + std::string(kImageSetsExtension); | |||||
| image_sets_file = folder_path_ + std::string(kImageSetsSegmentation) + usage_ + std::string(kImageSetsExtension); | |||||
| } else if (task_type_ == TaskType::Detection) { | } else if (task_type_ == TaskType::Detection) { | ||||
| image_sets_file = folder_path_ + std::string(kImageSetsMain) + task_mode_ + std::string(kImageSetsExtension); | |||||
| image_sets_file = folder_path_ + std::string(kImageSetsMain) + usage_ + std::string(kImageSetsExtension); | |||||
| } | } | ||||
| std::ifstream in_file; | std::ifstream in_file; | ||||
| in_file.open(image_sets_file); | in_file.open(image_sets_file); | ||||
| @@ -431,13 +431,13 @@ Status VOCOp::CountTotalRows(const std::string &dir, const std::string &task_typ | |||||
| std::shared_ptr<VOCOp> op; | std::shared_ptr<VOCOp> op; | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); | |||||
| Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op)); | |||||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | RETURN_IF_NOT_OK(op->ParseImageIds()); | ||||
| RETURN_IF_NOT_OK(op->ParseAnnotationIds()); | RETURN_IF_NOT_OK(op->ParseAnnotationIds()); | ||||
| *count = static_cast<int64_t>(op->image_ids_.size()); | *count = static_cast<int64_t>(op->image_ids_.size()); | ||||
| } else if (task_type == "Segmentation") { | } else if (task_type == "Segmentation") { | ||||
| std::shared_ptr<VOCOp> op; | std::shared_ptr<VOCOp> op; | ||||
| RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).Build(&op)); | |||||
| RETURN_IF_NOT_OK(Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).Build(&op)); | |||||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | RETURN_IF_NOT_OK(op->ParseImageIds()); | ||||
| *count = static_cast<int64_t>(op->image_ids_.size()); | *count = static_cast<int64_t>(op->image_ids_.size()); | ||||
| } | } | ||||
| @@ -458,7 +458,7 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t | |||||
| } else { | } else { | ||||
| std::shared_ptr<VOCOp> op; | std::shared_ptr<VOCOp> op; | ||||
| RETURN_IF_NOT_OK( | RETURN_IF_NOT_OK( | ||||
| Builder().SetDir(dir).SetTask(task_type).SetMode(task_mode).SetClassIndex(input_class_indexing).Build(&op)); | |||||
| Builder().SetDir(dir).SetTask(task_type).SetUsage(task_mode).SetClassIndex(input_class_indexing).Build(&op)); | |||||
| RETURN_IF_NOT_OK(op->ParseImageIds()); | RETURN_IF_NOT_OK(op->ParseImageIds()); | ||||
| RETURN_IF_NOT_OK(op->ParseAnnotationIds()); | RETURN_IF_NOT_OK(op->ParseAnnotationIds()); | ||||
| for (const auto label : op->label_index_) { | for (const auto label : op->label_index_) { | ||||
| @@ -73,7 +73,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Setter method. | // Setter method. | ||||
| // @param const std::string & task_type | |||||
| // @param const std::string &task_type | |||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetTask(const std::string &task_type) { | Builder &SetTask(const std::string &task_type) { | ||||
| if (task_type == "Segmentation") { | if (task_type == "Segmentation") { | ||||
| @@ -85,10 +85,10 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| } | } | ||||
| // Setter method. | // Setter method. | ||||
| // @param const std::string & task_mode | |||||
| // @param const std::string &usage | |||||
| // @return Builder setter method returns reference to the builder. | // @return Builder setter method returns reference to the builder. | ||||
| Builder &SetMode(const std::string &task_mode) { | |||||
| builder_task_mode_ = task_mode; | |||||
| Builder &SetUsage(const std::string &usage) { | |||||
| builder_usage_ = usage; | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -145,7 +145,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| bool builder_decode_; | bool builder_decode_; | ||||
| std::string builder_dir_; | std::string builder_dir_; | ||||
| TaskType builder_task_type_; | TaskType builder_task_type_; | ||||
| std::string builder_task_mode_; | |||||
| std::string builder_usage_; | |||||
| int32_t builder_num_workers_; | int32_t builder_num_workers_; | ||||
| int32_t builder_op_connector_size_; | int32_t builder_op_connector_size_; | ||||
| int32_t builder_rows_per_buffer_; | int32_t builder_rows_per_buffer_; | ||||
| @@ -279,7 +279,7 @@ class VOCOp : public ParallelOp, public RandomAccessOp { | |||||
| int64_t buf_cnt_; | int64_t buf_cnt_; | ||||
| std::string folder_path_; | std::string folder_path_; | ||||
| TaskType task_type_; | TaskType task_type_; | ||||
| std::string task_mode_; | |||||
| std::string usage_; | |||||
| int32_t rows_per_buffer_; | int32_t rows_per_buffer_; | ||||
| std::unique_ptr<DataSchema> data_schema_; | std::unique_ptr<DataSchema> data_schema_; | ||||
| @@ -111,34 +111,36 @@ std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const std::s | |||||
| /// \brief Function to create a CelebADataset | /// \brief Function to create a CelebADataset | ||||
| /// \notes The generated dataset has two columns ['image', 'attr']. | /// \notes The generated dataset has two columns ['image', 'attr']. | ||||
| // The type of the image tensor is uint8. The attr tensor is uint32 and one hot type. | |||||
| /// The type of the image tensor is uint8. The attr tensor is uint32 and one hot type. | |||||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | /// \param[in] dataset_dir Path to the root directory that contains the dataset. | ||||
| /// \param[in] dataset_type One of 'all', 'train', 'valid' or 'test'. | |||||
| /// \param[in] usage One of "all", "train", "valid" or "test". | |||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \param[in] decode Decode the images after reading (default=false). | /// \param[in] decode Decode the images after reading (default=false). | ||||
| /// \param[in] extensions Set of file extensions to be included in the dataset (default={}). | /// \param[in] extensions Set of file extensions to be included in the dataset (default={}). | ||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &dataset_type = "all", | |||||
| std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std::string &usage = "all", | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false, | const std::shared_ptr<SamplerObj> &sampler = RandomSampler(), bool decode = false, | ||||
| const std::set<std::string> &extensions = {}); | const std::set<std::string> &extensions = {}); | ||||
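For reference, a minimal sketch of calling the updated CelebA factory with the renamed parameter, assuming the datasets.h and samplers.h headers from this change are included; the directory is a placeholder and only the first two arguments are passed, so the documented defaults apply:

    // "valid" selects the validation partition; "all" (the default) reads every sample.
    std::shared_ptr<CelebADataset> celeba_ds = CelebA("/path/to/celeba", "valid");
    // A nullptr return indicates that parameter validation failed.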
| /// \brief Function to create a Cifar10 Dataset | /// \brief Function to create a Cifar10 Dataset | ||||
| /// \notes The generated dataset has two columns ['image', 'label'] | |||||
| /// \notes The generated dataset has two columns ["image", "label"] | |||||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset | /// \param[in] dataset_dir Path to the root directory that contains the dataset | ||||
| /// \param[in] usage Usage of CIFAR10, can be "train", "test" or "all" | |||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, | |||||
| std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, const std::string &usage = std::string(), | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | ||||
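As an illustrative sketch (the path is a placeholder), the new usage argument narrows Cifar10 to one split:

    // "train" reads the 50,000 training images, "test" the 10,000 test images;
    // "all" or the empty-string default is expected to read both.
    std::shared_ptr<Cifar10Dataset> cifar10_ds = Cifar10("/path/to/cifar-10-batches-bin", "train");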
| /// \brief Function to create a Cifar100 Dataset | /// \brief Function to create a Cifar100 Dataset | ||||
| /// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label'] | |||||
| /// \notes The generated dataset has three columns ["image", "coarse_label", "fine_label"] | |||||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset | /// \param[in] dataset_dir Path to the root directory that contains the dataset | ||||
| /// \param[in] usage Usage of CIFAR100, can be "train", "test" or "all" | |||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, | |||||
| std::shared_ptr<Cifar100Dataset> Cifar100(const std::string &dataset_dir, const std::string &usage = std::string(), | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | ||||
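Cifar100 follows the same pattern; a hedged one-line sketch with a placeholder path:

    std::shared_ptr<Cifar100Dataset> cifar100_ds = Cifar100("/path/to/cifar-100-binary", "test");  // test split only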
| /// \brief Function to create a CLUEDataset | /// \brief Function to create a CLUEDataset | ||||
| @@ -212,7 +214,7 @@ std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, c | |||||
| /// \brief Function to create an ImageFolderDataset | /// \brief Function to create an ImageFolderDataset | ||||
| /// \notes A source dataset that reads images from a tree of directories | /// \notes A source dataset that reads images from a tree of directories | ||||
| /// All images within one folder have the same label | /// All images within one folder have the same label | ||||
| /// The generated dataset has two columns ['image', 'label'] | |||||
| /// The generated dataset has two columns ["image", "label"] | |||||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset | /// \param[in] dataset_dir Path to the root directory that contains the dataset | ||||
| /// \param[in] decode A flag to decode in ImageFolder | /// \param[in] decode A flag to decode in ImageFolder | ||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| @@ -227,7 +229,7 @@ std::shared_ptr<ImageFolderDataset> ImageFolder(const std::string &dataset_dir, | |||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| /// \brief Function to create a ManifestDataset | /// \brief Function to create a ManifestDataset | ||||
| /// \notes The generated dataset has two columns ['image', 'label'] | |||||
| /// \notes The generated dataset has two columns ["image", "label"] | |||||
| /// \param[in] dataset_file The dataset file to be read | /// \param[in] dataset_file The dataset file to be read | ||||
| /// \param[in] usage Need "train", "eval" or "inference" data (default="train") | /// \param[in] usage Need "train", "eval" or "inference" data (default="train") | ||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| @@ -243,12 +245,13 @@ std::shared_ptr<ManifestDataset> Manifest(const std::string &dataset_file, const | |||||
| #endif | #endif | ||||
| /// \brief Function to create a MnistDataset | /// \brief Function to create a MnistDataset | ||||
| /// \notes The generated dataset has two columns ['image', 'label'] | |||||
| /// \notes The generated dataset has two columns ["image", "label"] | |||||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset | /// \param[in] dataset_dir Path to the root directory that contains the dataset | ||||
| /// \param[in] usage Usage of MNIST, can be "train", "test" or "all" | |||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \return Shared pointer to the current MnistDataset | /// \return Shared pointer to the current MnistDataset | ||||
| std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, | |||||
| std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const std::string &usage = std::string(), | |||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | ||||
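A minimal, illustrative call of the updated Mnist factory (the directory is hypothetical); "test" limits reading to the t10k idx files matched by MnistOp::WalkAllFiles, while the op builder's SanityCheck shown earlier rejects anything outside "train", "test", "all" or the empty default:

    std::shared_ptr<MnistDataset> mnist_ds = Mnist("/path/to/mnist", "test");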
| /// \brief Function to create a ConcatDataset | /// \brief Function to create a ConcatDataset | ||||
| @@ -404,14 +407,14 @@ std::shared_ptr<TFRecordDataset> TFRecord(const std::vector<std::string> &datase | |||||
| /// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]]. | /// - task='Segmentation', column: [['image', dtype=uint8], ['target',dtype=uint8]]. | ||||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset | /// \param[in] dataset_dir Path to the root directory that contains the dataset | ||||
| /// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection" | /// \param[in] task Set the task type of reading voc data, now only support "Segmentation" or "Detection" | ||||
| /// \param[in] mode Set the data list txt file to be readed | |||||
| /// \param[in] usage The type of data list text file to be read | |||||
| /// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task | /// \param[in] class_indexing A str-to-int mapping from label name to index, only valid in "Detection" task | ||||
| /// \param[in] decode Decode the images after reading | /// \param[in] decode Decode the images after reading | ||||
| /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | /// \param[in] sampler Object used to choose samples from the dataset. If sampler is not given, | ||||
| /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | /// a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()) | ||||
| /// \return Shared pointer to the current Dataset | /// \return Shared pointer to the current Dataset | ||||
| std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", | std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std::string &task = "Segmentation", | ||||
| const std::string &mode = "train", | |||||
| const std::string &usage = "train", | |||||
| const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false, | const std::map<std::string, int32_t> &class_indexing = {}, bool decode = false, | ||||
| const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | const std::shared_ptr<SamplerObj> &sampler = RandomSampler()); | ||||
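Finally, a sketch of the VOC call after the mode-to-usage rename; the VOC2012 path is a placeholder, and "train" names the ImageSets list file that VOCOp::ParseImageIds opens:

    std::shared_ptr<VOCDataset> voc_ds = VOC("/path/to/VOCdevkit/VOC2012", "Detection", "train");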
| #endif | #endif | ||||
| @@ -702,9 +705,8 @@ class AlbumDataset : public Dataset { | |||||
| class CelebADataset : public Dataset { | class CelebADataset : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| CelebADataset(const std::string &dataset_dir, const std::string &dataset_type, | |||||
| const std::shared_ptr<SamplerObj> &sampler, const bool &decode, | |||||
| const std::set<std::string> &extensions); | |||||
| CelebADataset(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler, | |||||
| const bool &decode, const std::set<std::string> &extensions); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~CelebADataset() = default; | ~CelebADataset() = default; | ||||
| @@ -719,7 +721,7 @@ class CelebADataset : public Dataset { | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string dataset_type_; | |||||
| std::string usage_; | |||||
| bool decode_; | bool decode_; | ||||
| std::set<std::string> extensions_; | std::set<std::string> extensions_; | ||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| @@ -730,7 +732,7 @@ class CelebADataset : public Dataset { | |||||
| class Cifar10Dataset : public Dataset { | class Cifar10Dataset : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar10Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler); | |||||
| Cifar10Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Cifar10Dataset() = default; | ~Cifar10Dataset() = default; | ||||
| @@ -745,13 +747,14 @@ class Cifar10Dataset : public Dataset { | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| }; | }; | ||||
| class Cifar100Dataset : public Dataset { | class Cifar100Dataset : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| Cifar100Dataset(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler); | |||||
| Cifar100Dataset(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~Cifar100Dataset() = default; | ~Cifar100Dataset() = default; | ||||
| @@ -766,6 +769,7 @@ class Cifar100Dataset : public Dataset { | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| }; | }; | ||||
| @@ -831,7 +835,7 @@ class CocoDataset : public Dataset { | |||||
| enum CsvType : uint8_t { INT = 0, FLOAT, STRING }; | enum CsvType : uint8_t { INT = 0, FLOAT, STRING }; | ||||
| /// \brief Base class of CSV Record | /// \brief Base class of CSV Record | ||||
| struct CsvBase { | |||||
| class CsvBase { | |||||
| public: | public: | ||||
| CsvBase() = default; | CsvBase() = default; | ||||
| explicit CsvBase(CsvType t) : type(t) {} | explicit CsvBase(CsvType t) : type(t) {} | ||||
| @@ -936,7 +940,7 @@ class ManifestDataset : public Dataset { | |||||
| class MnistDataset : public Dataset { | class MnistDataset : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| MnistDataset(std::string dataset_dir, std::shared_ptr<SamplerObj> sampler); | |||||
| MnistDataset(std::string dataset_dir, std::string usage, std::shared_ptr<SamplerObj> sampler); | |||||
| /// \brief Destructor | /// \brief Destructor | ||||
| ~MnistDataset() = default; | ~MnistDataset() = default; | ||||
| @@ -951,6 +955,7 @@ class MnistDataset : public Dataset { | |||||
| private: | private: | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string usage_; | |||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| }; | }; | ||||
| @@ -1087,7 +1092,7 @@ class TFRecordDataset : public Dataset { | |||||
| class VOCDataset : public Dataset { | class VOCDataset : public Dataset { | ||||
| public: | public: | ||||
| /// \brief Constructor | /// \brief Constructor | ||||
| VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode, | |||||
| VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &usage, | |||||
| const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler); | const std::map<std::string, int32_t> &class_indexing, bool decode, std::shared_ptr<SamplerObj> sampler); | ||||
| /// \brief Destructor | /// \brief Destructor | ||||
| @@ -1110,7 +1115,7 @@ class VOCDataset : public Dataset { | |||||
| const std::string kColumnTruncate = "truncate"; | const std::string kColumnTruncate = "truncate"; | ||||
| std::string dataset_dir_; | std::string dataset_dir_; | ||||
| std::string task_; | std::string task_; | ||||
| std::string mode_; | |||||
| std::string usage_; | |||||
| std::map<std::string, int32_t> class_index_; | std::map<std::string, int32_t> class_index_; | ||||
| bool decode_; | bool decode_; | ||||
| std::shared_ptr<SamplerObj> sampler_; | std::shared_ptr<SamplerObj> sampler_; | ||||
| @@ -132,6 +132,12 @@ def check_valid_detype(type_): | |||||
| return True | return True | ||||
| def check_valid_str(value, valid_strings, arg_name=""): | |||||
| type_check(value, (str,), arg_name) | |||||
| if value not in valid_strings: | |||||
| raise ValueError("Input {0} is not within the valid set of {1}.".format(arg_name, str(valid_strings))) | |||||
| def check_columns(columns, name): | def check_columns(columns, name): | ||||
| """ | """ | ||||
| Validate strings in column_names. | Validate strings in column_names. | ||||
| @@ -2877,6 +2877,9 @@ class MnistDataset(MappableDataset): | |||||
| Args: | Args: | ||||
| dataset_dir (str): Path to the root directory that contains the dataset. | dataset_dir (str): Path to the root directory that contains the dataset. | ||||
| usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 60,000 | |||||
| train samples, "test" will read from 10,000 test samples, "all" will read from all 70,000 samples. | |||||
| (default=None, all samples) | |||||
| num_samples (int, optional): The number of images to be included in the dataset | num_samples (int, optional): The number of images to be included in the dataset | ||||
| (default=None, all images). | (default=None, all images). | ||||
| num_parallel_workers (int, optional): Number of workers to read the data | num_parallel_workers (int, optional): Number of workers to read the data | ||||
| @@ -2906,11 +2909,12 @@ class MnistDataset(MappableDataset): | |||||
| """ | """ | ||||
| @check_mnist_cifar_dataset | @check_mnist_cifar_dataset | ||||
| def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, | |||||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | |||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | shuffle=None, sampler=None, num_shards=None, shard_id=None): | ||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.usage = usage | |||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| @@ -2920,6 +2924,7 @@ class MnistDataset(MappableDataset): | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["dataset_dir"] = self.dataset_dir | args["dataset_dir"] = self.dataset_dir | ||||
| args["usage"] = self.usage | |||||
| args["num_samples"] = self.num_samples | args["num_samples"] = self.num_samples | ||||
| args["shuffle"] = self.shuffle_level | args["shuffle"] = self.shuffle_level | ||||
| args["sampler"] = self.sampler | args["sampler"] = self.sampler | ||||
| @@ -2935,7 +2940,7 @@ class MnistDataset(MappableDataset): | |||||
| Number, number of batches. | Number, number of batches. | ||||
| """ | """ | ||||
| if self.dataset_size is None: | if self.dataset_size is None: | ||||
| num_rows = MnistOp.get_num_rows(self.dataset_dir) | |||||
| num_rows = MnistOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage) | |||||
| self.dataset_size = get_num_rows(num_rows, self.num_shards) | self.dataset_size = get_num_rows(num_rows, self.num_shards) | ||||
| rows_from_sampler = self._get_sampler_dataset_size() | rows_from_sampler = self._get_sampler_dataset_size() | ||||
| if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: | if rows_from_sampler is not None and rows_from_sampler < self.dataset_size: | ||||
| @@ -3913,6 +3918,9 @@ class Cifar10Dataset(MappableDataset): | |||||
| Args: | Args: | ||||
| dataset_dir (str): Path to the root directory that contains the dataset. | dataset_dir (str): Path to the root directory that contains the dataset. | ||||
| usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 50,000 | |||||
| train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples. | |||||
| (default=None, all samples) | |||||
| num_samples (int, optional): The number of images to be included in the dataset. | num_samples (int, optional): The number of images to be included in the dataset. | ||||
| (default=None, all images). | (default=None, all images). | ||||
| num_parallel_workers (int, optional): Number of workers to read the data | num_parallel_workers (int, optional): Number of workers to read the data | ||||
| @@ -3946,11 +3954,12 @@ class Cifar10Dataset(MappableDataset): | |||||
| """ | """ | ||||
| @check_mnist_cifar_dataset | @check_mnist_cifar_dataset | ||||
| def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, | |||||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | |||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | shuffle=None, sampler=None, num_shards=None, shard_id=None): | ||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.usage = usage | |||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| @@ -3960,6 +3969,7 @@ class Cifar10Dataset(MappableDataset): | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["dataset_dir"] = self.dataset_dir | args["dataset_dir"] = self.dataset_dir | ||||
| args["usage"] = self.usage | |||||
| args["num_samples"] = self.num_samples | args["num_samples"] = self.num_samples | ||||
| args["sampler"] = self.sampler | args["sampler"] = self.sampler | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| @@ -3975,7 +3985,7 @@ class Cifar10Dataset(MappableDataset): | |||||
| Number, number of batches. | Number, number of batches. | ||||
| """ | """ | ||||
| if self.dataset_size is None: | if self.dataset_size is None: | ||||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, True) | |||||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, True) | |||||
| self.dataset_size = get_num_rows(num_rows, self.num_shards) | self.dataset_size = get_num_rows(num_rows, self.num_shards) | ||||
| rows_from_sampler = self._get_sampler_dataset_size() | rows_from_sampler = self._get_sampler_dataset_size() | ||||
| @@ -4051,6 +4061,9 @@ class Cifar100Dataset(MappableDataset): | |||||
| Args: | Args: | ||||
| dataset_dir (str): Path to the root directory that contains the dataset. | dataset_dir (str): Path to the root directory that contains the dataset. | ||||
| usage (str, optional): Usage of this dataset, can be "train", "test" or "all". "train" will read from 50,000 | |||||
| train samples, "test" will read from 10,000 test samples, "all" will read from all 60,000 samples. | |||||
| (default=None, all samples) | |||||
| num_samples (int, optional): The number of images to be included in the dataset. | num_samples (int, optional): The number of images to be included in the dataset. | ||||
| (default=None, all images). | (default=None, all images). | ||||
| num_parallel_workers (int, optional): Number of workers to read the data | num_parallel_workers (int, optional): Number of workers to read the data | ||||
| @@ -4082,11 +4095,12 @@ class Cifar100Dataset(MappableDataset): | |||||
| """ | """ | ||||
| @check_mnist_cifar_dataset | @check_mnist_cifar_dataset | ||||
| def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, | |||||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, | |||||
| shuffle=None, sampler=None, num_shards=None, shard_id=None): | shuffle=None, sampler=None, num_shards=None, shard_id=None): | ||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.usage = usage | |||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| @@ -4096,6 +4110,7 @@ class Cifar100Dataset(MappableDataset): | |||||
| def get_args(self): | def get_args(self): | ||||
| args = super().get_args() | args = super().get_args() | ||||
| args["dataset_dir"] = self.dataset_dir | args["dataset_dir"] = self.dataset_dir | ||||
| args["usage"] = self.usage | |||||
| args["num_samples"] = self.num_samples | args["num_samples"] = self.num_samples | ||||
| args["sampler"] = self.sampler | args["sampler"] = self.sampler | ||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| @@ -4111,7 +4126,7 @@ class Cifar100Dataset(MappableDataset): | |||||
| Number, number of batches. | Number, number of batches. | ||||
| """ | """ | ||||
| if self.dataset_size is None: | if self.dataset_size is None: | ||||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, False) | |||||
| num_rows = CifarOp.get_num_rows(self.dataset_dir, "all" if self.usage is None else self.usage, False) | |||||
| self.dataset_size = get_num_rows(num_rows, self.num_shards) | self.dataset_size = get_num_rows(num_rows, self.num_shards) | ||||
| rows_from_sampler = self._get_sampler_dataset_size() | rows_from_sampler = self._get_sampler_dataset_size() | ||||
| @@ -4467,7 +4482,7 @@ class VOCDataset(MappableDataset): | |||||
| dataset_dir (str): Path to the root directory that contains the dataset. | dataset_dir (str): Path to the root directory that contains the dataset. | ||||
| task (str): Set the task type of reading voc data, now only support "Segmentation" or "Detection" | task (str): Set the task type of reading voc data, now only support "Segmentation" or "Detection" | ||||
| (default="Segmentation"). | (default="Segmentation"). | ||||
| mode (str): Set the data list txt file to be readed (default="train"). | |||||
| usage (str): The type of data list text file to be read (default="train"). | |||||
| class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in | class_indexing (dict, optional): A str-to-int mapping from label name to index, only valid in | ||||
| "Detection" task (default=None, the folder names will be sorted alphabetically and each | "Detection" task (default=None, the folder names will be sorted alphabetically and each | ||||
| class will be given a unique index starting from 0). | class will be given a unique index starting from 0). | ||||
| @@ -4502,24 +4517,24 @@ class VOCDataset(MappableDataset): | |||||
| >>> import mindspore.dataset as ds | >>> import mindspore.dataset as ds | ||||
| >>> dataset_dir = "/path/to/voc_dataset_directory" | >>> dataset_dir = "/path/to/voc_dataset_directory" | ||||
| >>> # 1) read VOC data for segmentation train | >>> # 1) read VOC data for segmentation train | ||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", mode="train") | |||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Segmentation", usage="train") | |||||
| >>> # 2) read VOC data for detection train | >>> # 2) read VOC data for detection train | ||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train") | |||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train") | |||||
| >>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order: | >>> # 3) read all VOC dataset samples in dataset_dir with 8 threads in random order: | ||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", num_parallel_workers=8) | |||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", num_parallel_workers=8) | |||||
| >>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence: | >>> # 4) read then decode all VOC dataset samples in dataset_dir in sequence: | ||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", mode="train", decode=True, shuffle=False) | |||||
| >>> voc_dataset = ds.VOCDataset(dataset_dir, task="Detection", usage="train", decode=True, shuffle=False) | |||||
| >>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target" | >>> # in VOC dataset, if task='Segmentation', each dictionary has keys "image" and "target" | ||||
| >>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation" | >>> # in VOC dataset, if task='Detection', each dictionary has keys "image" and "annotation" | ||||
| """ | """ | ||||
| @check_vocdataset | @check_vocdataset | ||||
| def __init__(self, dataset_dir, task="Segmentation", mode="train", class_indexing=None, num_samples=None, | |||||
| def __init__(self, dataset_dir, task="Segmentation", usage="train", class_indexing=None, num_samples=None, | |||||
| num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): | num_parallel_workers=None, shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): | ||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.task = task | self.task = task | ||||
| self.mode = mode | |||||
| self.usage = usage | |||||
| self.class_indexing = class_indexing | self.class_indexing = class_indexing | ||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| @@ -4532,7 +4547,7 @@ class VOCDataset(MappableDataset): | |||||
| args = super().get_args() | args = super().get_args() | ||||
| args["dataset_dir"] = self.dataset_dir | args["dataset_dir"] = self.dataset_dir | ||||
| args["task"] = self.task | args["task"] = self.task | ||||
| args["mode"] = self.mode | |||||
| args["usage"] = self.usage | |||||
| args["class_indexing"] = self.class_indexing | args["class_indexing"] = self.class_indexing | ||||
| args["num_samples"] = self.num_samples | args["num_samples"] = self.num_samples | ||||
| args["sampler"] = self.sampler | args["sampler"] = self.sampler | ||||
| @@ -4560,7 +4575,7 @@ class VOCDataset(MappableDataset): | |||||
| else: | else: | ||||
| class_indexing = self.class_indexing | class_indexing = self.class_indexing | ||||
| num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.mode, class_indexing, num_samples) | |||||
| num_rows = VOCOp.get_num_rows(self.dataset_dir, self.task, self.usage, class_indexing, num_samples) | |||||
| self.dataset_size = get_num_rows(num_rows, self.num_shards) | self.dataset_size = get_num_rows(num_rows, self.num_shards) | ||||
| rows_from_sampler = self._get_sampler_dataset_size() | rows_from_sampler = self._get_sampler_dataset_size() | ||||
| @@ -4584,7 +4599,7 @@ class VOCDataset(MappableDataset): | |||||
| else: | else: | ||||
| class_indexing = self.class_indexing | class_indexing = self.class_indexing | ||||
| return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.mode, class_indexing) | |||||
| return VOCOp.get_class_indexing(self.dataset_dir, self.task, self.usage, class_indexing) | |||||
| def is_shuffled(self): | def is_shuffled(self): | ||||
| if self.shuffle_level is None: | if self.shuffle_level is None: | ||||
| @@ -4824,7 +4839,7 @@ class CelebADataset(MappableDataset): | |||||
| dataset_dir (str): Path to the root directory that contains the dataset. | dataset_dir (str): Path to the root directory that contains the dataset. | ||||
| num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config). | num_parallel_workers (int, optional): Number of workers to read the data (default=value set in the config). | ||||
| shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None). | shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None). | ||||
| dataset_type (str): one of 'all', 'train', 'valid' or 'test'. | |||||
| usage (str): one of 'all', 'train', 'valid' or 'test'. | |||||
| sampler (Sampler, optional): Object used to choose samples from the dataset (default=None). | sampler (Sampler, optional): Object used to choose samples from the dataset (default=None). | ||||
| decode (bool, optional): decode the images after reading (default=False). | decode (bool, optional): decode the images after reading (default=False). | ||||
| extensions (list[str], optional): List of file extensions to be | extensions (list[str], optional): List of file extensions to be | ||||
| @@ -4838,8 +4853,8 @@ class CelebADataset(MappableDataset): | |||||
| """ | """ | ||||
| @check_celebadataset | @check_celebadataset | ||||
| def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, dataset_type='all', | |||||
| sampler=None, decode=False, extensions=None, num_samples=None, num_shards=None, shard_id=None): | |||||
| def __init__(self, dataset_dir, num_parallel_workers=None, shuffle=None, usage='all', sampler=None, decode=False, | |||||
| extensions=None, num_samples=None, num_shards=None, shard_id=None): | |||||
| super().__init__(num_parallel_workers) | super().__init__(num_parallel_workers) | ||||
| self.dataset_dir = dataset_dir | self.dataset_dir = dataset_dir | ||||
| self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) | ||||
| @@ -4847,7 +4862,7 @@ class CelebADataset(MappableDataset): | |||||
| self.decode = decode | self.decode = decode | ||||
| self.extensions = extensions | self.extensions = extensions | ||||
| self.num_samples = num_samples | self.num_samples = num_samples | ||||
| self.dataset_type = dataset_type | |||||
| self.usage = usage | |||||
| self.num_shards = num_shards | self.num_shards = num_shards | ||||
| self.shard_id = shard_id | self.shard_id = shard_id | ||||
| self.shuffle_level = shuffle | self.shuffle_level = shuffle | ||||
| @@ -4860,7 +4875,7 @@ class CelebADataset(MappableDataset): | |||||
| args["decode"] = self.decode | args["decode"] = self.decode | ||||
| args["extensions"] = self.extensions | args["extensions"] = self.extensions | ||||
| args["num_samples"] = self.num_samples | args["num_samples"] = self.num_samples | ||||
| args["dataset_type"] = self.dataset_type | |||||
| args["usage"] = self.usage | |||||
| args["num_shards"] = self.num_shards | args["num_shards"] = self.num_shards | ||||
| args["shard_id"] = self.shard_id | args["shard_id"] = self.shard_id | ||||
| return args | return args | ||||
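With the keyword renamed from dataset_type to usage, a caller now selects the CelebA split through usage; a minimal sketch (the dataset path and the printed column names are assumptions, not taken from this diff):

import mindspore.dataset as ds

celeba_dir = "/path/to/celeba"  # placeholder path

# 'usage' replaces the old 'dataset_type' keyword; accepted values are
# 'all', 'train', 'valid' and 'test'.
celeba_train = ds.CelebADataset(celeba_dir, usage="train", shuffle=False, decode=True)
for item in celeba_train.create_dict_iterator():
    # column names assumed to be "image" and "attr" as produced by the CelebA op
    print(item["image"].shape, item["attr"])
    break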
| @@ -273,7 +273,7 @@ def create_node(node): | |||||
| elif dataset_op == 'MnistDataset': | elif dataset_op == 'MnistDataset': | ||||
| sampler = construct_sampler(node.get('sampler')) | sampler = construct_sampler(node.get('sampler')) | ||||
| pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'), | |||||
| pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'), | |||||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | ||||
| elif dataset_op == 'MindDataset': | elif dataset_op == 'MindDataset': | ||||
| @@ -296,12 +296,12 @@ def create_node(node): | |||||
| elif dataset_op == 'Cifar10Dataset': | elif dataset_op == 'Cifar10Dataset': | ||||
| sampler = construct_sampler(node.get('sampler')) | sampler = construct_sampler(node.get('sampler')) | ||||
| pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'), | |||||
| pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'), | |||||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | ||||
| elif dataset_op == 'Cifar100Dataset': | elif dataset_op == 'Cifar100Dataset': | ||||
| sampler = construct_sampler(node.get('sampler')) | sampler = construct_sampler(node.get('sampler')) | ||||
| pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'), | |||||
| pyobj = pyclass(node['dataset_dir'], node['usage'], node.get('num_samples'), node.get('num_parallel_workers'), | |||||
| node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) | ||||
| elif dataset_op == 'VOCDataset': | elif dataset_op == 'VOCDataset': | ||||
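Note that the deserializer above now indexes node['usage'] directly for Mnist/Cifar nodes, so the serialized node must carry that key (a missing key would raise KeyError). A sketch of what such a Cifar10Dataset node is assumed to look like; field values are illustrative only:

# Illustrative serialized node consumed by create_node() after this change.
node = {
    "dataset_dir": "/path/to/cifar10",
    "usage": "train",            # now required
    "num_samples": 100,
    "num_parallel_workers": 4,
    "shuffle": True,
    "sampler": None,
    "num_shards": None,
    "shard_id": None,
}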
| @@ -27,7 +27,7 @@ from mindspore.dataset.callback import DSCallback | |||||
| from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \ | ||||
| INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \ | ||||
| validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ | validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \ | ||||
| check_columns, check_pos_int32 | |||||
| check_columns, check_pos_int32, check_valid_str | |||||
| from . import datasets | from . import datasets | ||||
| from . import samplers | from . import samplers | ||||
| @@ -74,6 +74,10 @@ def check_mnist_cifar_dataset(method): | |||||
| dataset_dir = param_dict.get('dataset_dir') | dataset_dir = param_dict.get('dataset_dir') | ||||
| check_dir(dataset_dir) | check_dir(dataset_dir) | ||||
| usage = param_dict.get('usage') | |||||
| if usage is not None: | |||||
| check_valid_str(usage, ["train", "test", "all"], "usage") | |||||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | validate_dataset_param_value(nreq_param_int, param_dict, int) | ||||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | validate_dataset_param_value(nreq_param_bool, param_dict, bool) | ||||
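The body of check_valid_str is not shown in this diff; a sketch of what the helper is assumed to do, mirroring the call made above (the real implementation in core/validator_helpers.py may word its errors differently):

def check_valid_str(value, valid_strings, arg_name=""):
    # Sketch only: reject non-strings and values outside the allowed set.
    if not isinstance(value, str):
        raise TypeError("{0} must be of type str, got {1}.".format(arg_name, type(value)))
    if value not in valid_strings:
        raise ValueError("{0} should be one of {1}, got {2}.".format(arg_name, valid_strings, value))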
| @@ -154,15 +158,15 @@ def check_vocdataset(method): | |||||
| task = param_dict.get('task') | task = param_dict.get('task') | ||||
| type_check(task, (str,), "task") | type_check(task, (str,), "task") | ||||
| mode = param_dict.get('mode') | |||||
| type_check(mode, (str,), "mode") | |||||
| usage = param_dict.get('usage') | |||||
| type_check(usage, (str,), "usage") | |||||
| if task == "Segmentation": | if task == "Segmentation": | ||||
| imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", mode + ".txt") | |||||
| imagesets_file = os.path.join(dataset_dir, "ImageSets", "Segmentation", usage + ".txt") | |||||
| if param_dict.get('class_indexing') is not None: | if param_dict.get('class_indexing') is not None: | ||||
| raise ValueError("class_indexing is invalid in Segmentation task") | raise ValueError("class_indexing is invalid in Segmentation task") | ||||
| elif task == "Detection": | elif task == "Detection": | ||||
| imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", mode + ".txt") | |||||
| imagesets_file = os.path.join(dataset_dir, "ImageSets", "Main", usage + ".txt") | |||||
| else: | else: | ||||
| raise ValueError("Invalid task : " + task) | raise ValueError("Invalid task : " + task) | ||||
| @@ -235,9 +239,9 @@ def check_celebadataset(method): | |||||
| validate_dataset_param_value(nreq_param_list, param_dict, list) | validate_dataset_param_value(nreq_param_list, param_dict, list) | ||||
| validate_dataset_param_value(nreq_param_str, param_dict, str) | validate_dataset_param_value(nreq_param_str, param_dict, str) | ||||
| dataset_type = param_dict.get('dataset_type') | |||||
| if dataset_type is not None and dataset_type not in ('all', 'train', 'valid', 'test'): | |||||
| raise ValueError("dataset_type should be one of 'all', 'train', 'valid' or 'test'.") | |||||
| usage = param_dict.get('usage') | |||||
| if usage is not None and usage not in ('all', 'train', 'valid', 'test'): | |||||
| raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.") | |||||
| check_sampler_shuffle_shard_options(param_dict) | check_sampler_shuffle_shard_options(param_dict) | ||||
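The CelebA check keeps the same accepted values as before; only the argument name changes. A tiny local mirror of the branch above, for illustration only (not a library function):

def _check_celeba_usage(usage):
    if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
        raise ValueError("usage should be one of 'all', 'train', 'valid' or 'test'.")

_check_celeba_usage("train")       # passes
_check_celeba_usage("validation")  # raises ValueError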
| @@ -5,7 +5,7 @@ SET(DE_UT_SRCS | |||||
| common/cvop_common.cc | common/cvop_common.cc | ||||
| common/bboxop_common.cc | common/bboxop_common.cc | ||||
| auto_contrast_op_test.cc | auto_contrast_op_test.cc | ||||
| album_op_test.cc | |||||
| album_op_test.cc | |||||
| batch_op_test.cc | batch_op_test.cc | ||||
| bit_functions_test.cc | bit_functions_test.cc | ||||
| storage_container_test.cc | storage_container_test.cc | ||||
| @@ -62,8 +62,8 @@ SET(DE_UT_SRCS | |||||
| rescale_op_test.cc | rescale_op_test.cc | ||||
| resize_op_test.cc | resize_op_test.cc | ||||
| resize_with_bbox_op_test.cc | resize_with_bbox_op_test.cc | ||||
| rgba_to_bgr_op_test.cc | |||||
| rgba_to_rgb_op_test.cc | |||||
| rgba_to_bgr_op_test.cc | |||||
| rgba_to_rgb_op_test.cc | |||||
| schema_test.cc | schema_test.cc | ||||
| skip_op_test.cc | skip_op_test.cc | ||||
| shuffle_op_test.cc | shuffle_op_test.cc | ||||
| @@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { | |||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create an iterator over the result of the above dataset | // Create an iterator over the result of the above dataset | ||||
| @@ -45,10 +45,10 @@ TEST_F(MindDataTestPipeline, TestCifar10Dataset) { | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| i++; | |||||
| auto image = row["image"]; | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | |||||
| } | } | ||||
| EXPECT_EQ(i, 10); | EXPECT_EQ(i, 10); | ||||
| @@ -62,7 +62,7 @@ TEST_F(MindDataTestPipeline, TestCifar100Dataset) { | |||||
| // Create a Cifar100 Dataset | // Create a Cifar100 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create an iterator over the result of the above dataset | // Create an iterator over the result of the above dataset | ||||
| @@ -96,7 +96,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetFail1) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar100DatasetFail1."; | ||||
| // Create a Cifar100 Dataset | // Create a Cifar100 Dataset | ||||
| std::shared_ptr<Dataset> ds = Cifar100("", RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar100("", std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_EQ(ds, nullptr); | EXPECT_EQ(ds, nullptr); | ||||
| } | } | ||||
| @@ -104,7 +104,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetFail1) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10DatasetFail1."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCifar10DatasetFail1."; | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::shared_ptr<Dataset> ds = Cifar10("", RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10("", std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_EQ(ds, nullptr); | EXPECT_EQ(ds, nullptr); | ||||
| } | } | ||||
| @@ -113,7 +113,7 @@ TEST_F(MindDataTestPipeline, TestCifar10DatasetWithNullSampler) { | |||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, nullptr); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), nullptr); | |||||
| // Expect failure: sampler can not be nullptr | // Expect failure: sampler can not be nullptr | ||||
| EXPECT_EQ(ds, nullptr); | EXPECT_EQ(ds, nullptr); | ||||
| } | } | ||||
| @@ -123,7 +123,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithNullSampler) { | |||||
| // Create a Cifar100 Dataset | // Create a Cifar100 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar100(folder_path, nullptr); | |||||
| std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), nullptr); | |||||
| // Expect failure: sampler can not be nullptr | // Expect failure: sampler can not be nullptr | ||||
| EXPECT_EQ(ds, nullptr); | EXPECT_EQ(ds, nullptr); | ||||
| } | } | ||||
| @@ -133,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestCifar100DatasetWithWrongSampler) { | |||||
| // Create a Cifar100 Dataset | // Create a Cifar100 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar100Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar100(folder_path, RandomSampler(false, -10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar100(folder_path, std::string(), RandomSampler(false, -10)); | |||||
| // Expect failure: sampler is not constructed correctly | // Expect failure: sampler is not constructed correctly | ||||
| EXPECT_EQ(ds, nullptr); | EXPECT_EQ(ds, nullptr); | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ TEST_F(MindDataTestPipeline, TestIteratorEmptyColumn) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorEmptyColumn."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorEmptyColumn."; | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 5)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 5)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Rename operation on ds | // Create a Rename operation on ds | ||||
| @@ -64,7 +64,7 @@ TEST_F(MindDataTestPipeline, TestIteratorOneColumn) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn."; | ||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -103,7 +103,7 @@ TEST_F(MindDataTestPipeline, TestIteratorReOrder) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorReOrder."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorReOrder."; | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, SequentialSampler(false, 4)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), SequentialSampler(false, 4)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Take operation on ds | // Create a Take operation on ds | ||||
| @@ -160,9 +160,8 @@ TEST_F(MindDataTestPipeline, TestIteratorTwoColumns) { | |||||
| // Iterate the dataset and get each row | // Iterate the dataset and get each row | ||||
| std::vector<std::shared_ptr<Tensor>> row; | std::vector<std::shared_ptr<Tensor>> row; | ||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| std::vector<TensorShape> expect = {TensorShape({173673}), TensorShape({1, 4}), | |||||
| TensorShape({173673}), TensorShape({1, 4}), | |||||
| TensorShape({147025}), TensorShape({1, 4}), | |||||
| std::vector<TensorShape> expect = {TensorShape({173673}), TensorShape({1, 4}), TensorShape({173673}), | |||||
| TensorShape({1, 4}), TensorShape({147025}), TensorShape({1, 4}), | |||||
| TensorShape({211653}), TensorShape({1, 4})}; | TensorShape({211653}), TensorShape({1, 4})}; | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| @@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestIteratorWrongColumn) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestIteratorOneColumn."; | ||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 4)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 4)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Pass wrong column name | // Pass wrong column name | ||||
| @@ -40,7 +40,7 @@ TEST_F(MindDataTestPipeline, TestBatchAndRepeat) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Repeat operation on ds | // Create a Repeat operation on ds | ||||
| @@ -82,7 +82,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess1) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| @@ -118,13 +118,12 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthSuccess2) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| std::map<std::string, std::pair<mindspore::dataset::TensorShape, std::shared_ptr<Tensor>>> pad_info; | std::map<std::string, std::pair<mindspore::dataset::TensorShape, std::shared_ptr<Tensor>>> pad_info; | ||||
| ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, 2, 3}, | |||||
| &BucketBatchTestFunction, pad_info, true, true); | |||||
| ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, 2, 3}, &BucketBatchTestFunction, pad_info, true, true); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create an iterator over the result of the above dataset | // Create an iterator over the result of the above dataset | ||||
| @@ -157,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail1) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| @@ -172,7 +171,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail2) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| @@ -187,7 +186,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail3) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| @@ -202,7 +201,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail4) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| @@ -217,7 +216,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail5) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| @@ -232,7 +231,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail6) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, -2, 3}); | ds = ds->BucketBatchByLength({"image"}, {1, 2}, {1, -2, 3}); | ||||
| @@ -246,7 +245,7 @@ TEST_F(MindDataTestPipeline, TestBucketBatchByLengthFail7) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a BucketBatchByLength operation on ds | // Create a BucketBatchByLength operation on ds | ||||
| @@ -313,7 +312,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess) { | |||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| // Column names: {"image", "label"} | // Column names: {"image", "label"} | ||||
| folder_path = datasets_root_path_ + "/testCifar10Data/"; | folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9)); | |||||
| std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9)); | |||||
| EXPECT_NE(ds2, nullptr); | EXPECT_NE(ds2, nullptr); | ||||
| // Create a Project operation on ds | // Create a Project operation on ds | ||||
| @@ -365,7 +364,7 @@ TEST_F(MindDataTestPipeline, TestConcatSuccess2) { | |||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| // Column names: {"image", "label"} | // Column names: {"image", "label"} | ||||
| folder_path = datasets_root_path_ + "/testCifar10Data/"; | folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 9)); | |||||
| std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 9)); | |||||
| EXPECT_NE(ds2, nullptr); | EXPECT_NE(ds2, nullptr); | ||||
| // Create a Project operation on ds | // Create a Project operation on ds | ||||
| @@ -704,11 +703,11 @@ TEST_F(MindDataTestPipeline, TestRenameSuccess) { | |||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestRepeatDefault) { | TEST_F(MindDataTestPipeline, TestRepeatDefault) { | ||||
| MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatDefault."; | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatDefault."; | |||||
| // Create an ImageFolder Dataset | // Create an ImageFolder Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | std::string folder_path = datasets_root_path_ + "/testPK/data/"; | ||||
| std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Repeat operation on ds | // Create a Repeat operation on ds | ||||
| @@ -723,21 +722,21 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) { | |||||
| // Create an iterator over the result of the above dataset | // Create an iterator over the result of the above dataset | ||||
| // This will trigger the creation of the Execution Tree and launch it. | // This will trigger the creation of the Execution Tree and launch it. | ||||
| std::shared_ptr <Iterator> iter = ds->CreateIterator(); | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| EXPECT_NE(iter, nullptr); | EXPECT_NE(iter, nullptr); | ||||
| // iterate over the dataset and get each row | // iterate over the dataset and get each row | ||||
| std::unordered_map <std::string, std::shared_ptr<Tensor>> row; | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size()!= 0) { | |||||
| while (row.size() != 0) { | |||||
| // manually stop | // manually stop | ||||
| if (i == 100) { | if (i == 100) { | ||||
| break; | break; | ||||
| } | } | ||||
| i++; | i++; | ||||
| auto image = row["image"]; | auto image = row["image"]; | ||||
| MS_LOG(INFO)<< "Tensor image shape: " << image->shape(); | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| } | } | ||||
| @@ -747,11 +746,11 @@ TEST_F(MindDataTestPipeline, TestRepeatDefault) { | |||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestRepeatOne) { | TEST_F(MindDataTestPipeline, TestRepeatOne) { | ||||
| MS_LOG(INFO)<< "Doing MindDataTestPipeline-TestRepeatOne."; | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRepeatOne."; | |||||
| // Create an ImageFolder Dataset | // Create an ImageFolder Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testPK/data/"; | std::string folder_path = datasets_root_path_ + "/testPK/data/"; | ||||
| std::shared_ptr <Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Repeat operation on ds | // Create a Repeat operation on ds | ||||
| @@ -766,17 +765,17 @@ TEST_F(MindDataTestPipeline, TestRepeatOne) { | |||||
| // Create an iterator over the result of the above dataset | // Create an iterator over the result of the above dataset | ||||
| // This will trigger the creation of the Execution Tree and launch it. | // This will trigger the creation of the Execution Tree and launch it. | ||||
| std::shared_ptr <Iterator> iter = ds->CreateIterator(); | |||||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||||
| EXPECT_NE(iter, nullptr); | EXPECT_NE(iter, nullptr); | ||||
| // iterate over the dataset and get each row | // iterate over the dataset and get each row | ||||
| std::unordered_map <std::string, std::shared_ptr<Tensor>> row; | |||||
| std::unordered_map<std::string, std::shared_ptr<Tensor>> row; | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size()!= 0) { | |||||
| while (row.size() != 0) { | |||||
| i++; | i++; | ||||
| auto image = row["image"]; | auto image = row["image"]; | ||||
| MS_LOG(INFO)<< "Tensor image shape: " << image->shape(); | |||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| } | } | ||||
| @@ -1013,7 +1012,7 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), RandomSampler(false, 20)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Repeat operation on ds | // Create a Repeat operation on ds | ||||
| @@ -1060,7 +1059,6 @@ TEST_F(MindDataTestPipeline, TestTensorOpsAndMap) { | |||||
| iter->Stop(); | iter->Stop(); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestZipFail) { | TEST_F(MindDataTestPipeline, TestZipFail) { | ||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipFail."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestZipFail."; | ||||
| // We expect this test to fail because both datasets we are zipping have "image" and "label" columns | // We expect this test to fail because both datasets we are zipping have "image" and "label" columns | ||||
| @@ -1128,7 +1126,7 @@ TEST_F(MindDataTestPipeline, TestZipSuccess) { | |||||
| EXPECT_NE(ds1, nullptr); | EXPECT_NE(ds1, nullptr); | ||||
| folder_path = datasets_root_path_ + "/testCifar10Data/"; | folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds2 = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds2, nullptr); | EXPECT_NE(ds2, nullptr); | ||||
| // Create a Project operation on ds | // Create a Project operation on ds | ||||
| @@ -43,10 +43,11 @@ TEST_F(MindDataTestPipeline, TestCelebADataset) { | |||||
| // Check if CelebAOp read correct images/attr | // Check if CelebAOp read correct images/attr | ||||
| std::string expect_file[] = {"1.JPEG", "2.jpg"}; | std::string expect_file[] = {"1.JPEG", "2.jpg"}; | ||||
| std::vector<std::vector<uint32_t>> expect_attr_vector = | |||||
| {{0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, | |||||
| 1, 0, 0, 1}, {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, | |||||
| 1, 0, 0, 0, 0, 0, 0, 0, 1}}; | |||||
| std::vector<std::vector<uint32_t>> expect_attr_vector = { | |||||
| {0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, | |||||
| 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1}, | |||||
| {0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, | |||||
| 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1}}; | |||||
| uint64_t i = 0; | uint64_t i = 0; | ||||
| while (row.size() != 0) { | while (row.size() != 0) { | ||||
| auto image = row["image"]; | auto image = row["image"]; | ||||
| @@ -132,7 +133,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithWrongDatasetDir) { | |||||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir."; | MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMnistFailWithWrongDatasetDir."; | ||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::shared_ptr<Dataset> ds = Mnist("", RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Mnist("", std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_EQ(ds, nullptr); | EXPECT_EQ(ds, nullptr); | ||||
| } | } | ||||
| @@ -141,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestMnistFailWithNullSampler) { | |||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, nullptr); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, std::string(), nullptr); | |||||
| // Expect failure: sampler can not be nullptr | // Expect failure: sampler can not be nullptr | ||||
| EXPECT_EQ(ds, nullptr); | EXPECT_EQ(ds, nullptr); | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) { | |||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| int number_of_classes = 10; | int number_of_classes = 10; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create objects for the tensor ops | // Create objects for the tensor ops | ||||
| @@ -38,7 +38,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) { | |||||
| EXPECT_NE(hwc_to_chw, nullptr); | EXPECT_NE(hwc_to_chw, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({hwc_to_chw},{"image"}); | |||||
| ds = ds->Map({hwc_to_chw}, {"image"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -51,10 +51,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0); | |||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = | |||||
| vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNCHW, 1.0, 1.0); | |||||
| EXPECT_NE(cutmix_batch_op, nullptr); | EXPECT_NE(cutmix_batch_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| @@ -77,10 +78,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess1) { | |||||
| auto label = row["label"]; | auto label = row["label"]; | ||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | ||||
| MS_LOG(INFO) << "Label shape: " << label->shape(); | MS_LOG(INFO) << "Label shape: " << label->shape(); | ||||
| EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1] | |||||
| && 32 == image->shape()[2] && 32 == image->shape()[3], true); | |||||
| EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 3 == image->shape()[1] && | |||||
| 32 == image->shape()[2] && 32 == image->shape()[3], | |||||
| true); | |||||
| EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && | EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && | ||||
| number_of_classes == label->shape()[1], true); | |||||
| number_of_classes == label->shape()[1], | |||||
| true); | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| } | } | ||||
| @@ -95,7 +98,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) { | |||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| int number_of_classes = 10; | int number_of_classes = 10; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -108,7 +111,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC); | std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC); | ||||
| @@ -134,10 +137,12 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchSuccess2) { | |||||
| auto label = row["label"]; | auto label = row["label"]; | ||||
| MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | MS_LOG(INFO) << "Tensor image shape: " << image->shape(); | ||||
| MS_LOG(INFO) << "Label shape: " << label->shape(); | MS_LOG(INFO) << "Label shape: " << label->shape(); | ||||
| EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1] | |||||
| && 32 == image->shape()[2] && 3 == image->shape()[3], true); | |||||
| EXPECT_EQ(image->shape().AsVector().size() == 4 && batch_size == image->shape()[0] && 32 == image->shape()[1] && | |||||
| 32 == image->shape()[2] && 3 == image->shape()[3], | |||||
| true); | |||||
| EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && | EXPECT_EQ(label->shape().AsVector().size() == 2 && batch_size == label->shape()[0] && | ||||
| number_of_classes == label->shape()[1], true); | |||||
| number_of_classes == label->shape()[1], | |||||
| true); | |||||
| iter->GetNextRow(&row); | iter->GetNextRow(&row); | ||||
| } | } | ||||
| @@ -151,7 +156,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) { | |||||
| // Must fail because alpha can't be negative | // Must fail because alpha can't be negative | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -164,10 +169,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail1) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5); | |||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = | |||||
| vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, -1, 0.5); | |||||
| EXPECT_EQ(cutmix_batch_op, nullptr); | EXPECT_EQ(cutmix_batch_op, nullptr); | ||||
| } | } | ||||
| @@ -175,7 +181,7 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) { | |||||
| // Must fail because prob can't be negative | // Must fail because prob can't be negative | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -188,20 +194,19 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail2) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, | |||||
| 1, -0.5); | |||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = | |||||
| vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 1, -0.5); | |||||
| EXPECT_EQ(cutmix_batch_op, nullptr); | EXPECT_EQ(cutmix_batch_op, nullptr); | ||||
| } | } | ||||
| TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) { | TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) { | ||||
| // Must fail because alpha can't be zero | // Must fail because alpha can't be zero | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -214,11 +219,11 @@ TEST_F(MindDataTestPipeline, TestCutMixBatchFail3) { | |||||
| EXPECT_NE(one_hot_op, nullptr); | EXPECT_NE(one_hot_op, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| ds = ds->Map({one_hot_op},{"label"}); | |||||
| ds = ds->Map({one_hot_op}, {"label"}); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, | |||||
| 0.0, 0.5); | |||||
| std::shared_ptr<TensorOperation> cutmix_batch_op = | |||||
| vision::CutMixBatch(mindspore::dataset::ImageBatchFormat::kNHWC, 0.0, 0.5); | |||||
| EXPECT_EQ(cutmix_batch_op, nullptr); | EXPECT_EQ(cutmix_batch_op, nullptr); | ||||
| } | } | ||||
| @@ -371,7 +376,7 @@ TEST_F(MindDataTestPipeline, TestHwcToChw) { | |||||
| TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) { | TEST_F(MindDataTestPipeline, TestMixUpBatchFail1) { | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -395,7 +400,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) { | |||||
| // This should fail because alpha can't be zero | // This should fail because alpha can't be zero | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -418,7 +423,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchFail2) { | |||||
| TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { | TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -467,7 +472,7 @@ TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess1) { | |||||
| TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) { | TEST_F(MindDataTestPipeline, TestMixUpBatchSuccess2) { | ||||
| // Create a Cifar10 Dataset | // Create a Cifar10 Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | std::string folder_path = datasets_root_path_ + "/testCifar10Data/"; | ||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, RandomSampler(false, 10)); | |||||
| std::shared_ptr<Dataset> ds = Cifar10(folder_path, std::string(), RandomSampler(false, 10)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Batch operation on ds | // Create a Batch operation on ds | ||||
| @@ -871,8 +876,7 @@ TEST_F(MindDataTestPipeline, TestRandomPosterizeSuccess1) { | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create objects for the tensor ops | // Create objects for the tensor ops | ||||
| std::shared_ptr<TensorOperation> posterize = | |||||
| vision::RandomPosterize({1, 4}); | |||||
| std::shared_ptr<TensorOperation> posterize = vision::RandomPosterize({1, 4}); | |||||
| EXPECT_NE(posterize, nullptr); | EXPECT_NE(posterize, nullptr); | ||||
| // Create a Map operation on ds | // Create a Map operation on ds | ||||
| @@ -1114,7 +1118,7 @@ TEST_F(MindDataTestPipeline, TestRandomRotation) { | |||||
| TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { | TEST_F(MindDataTestPipeline, TestUniformAugWithOps) { | ||||
| // Create a Mnist Dataset | // Create a Mnist Dataset | ||||
| std::string folder_path = datasets_root_path_ + "/testMnistData/"; | std::string folder_path = datasets_root_path_ + "/testMnistData/"; | ||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, RandomSampler(false, 20)); | |||||
| std::shared_ptr<Dataset> ds = Mnist(folder_path, "", RandomSampler(false, 20)); | |||||
| EXPECT_NE(ds, nullptr); | EXPECT_NE(ds, nullptr); | ||||
| // Create a Repeat operation on ds | // Create a Repeat operation on ds | ||||
| @@ -42,9 +42,13 @@ std::shared_ptr<CelebAOp> Celeba(int32_t num_workers, int32_t rows_per_buffer, i | |||||
| bool decode = false, const std::string &dataset_type="all") { | bool decode = false, const std::string &dataset_type="all") { | ||||
| std::shared_ptr<CelebAOp> so; | std::shared_ptr<CelebAOp> so; | ||||
| CelebAOp::Builder builder; | CelebAOp::Builder builder; | ||||
| Status rc = builder.SetNumWorkers(num_workers).SetCelebADir(dir).SetRowsPerBuffer(rows_per_buffer) | |||||
| .SetOpConnectorSize(queue_size).SetSampler(std::move(sampler)).SetDecode(decode) | |||||
| .SetDatasetType(dataset_type).Build(&so); | |||||
| Status rc = builder.SetNumWorkers(num_workers) | |||||
| .SetCelebADir(dir) | |||||
| .SetRowsPerBuffer(rows_per_buffer) | |||||
| .SetOpConnectorSize(queue_size) | |||||
| .SetSampler(std::move(sampler)) | |||||
| .SetDecode(decode) | |||||
| .SetUsage(dataset_type).Build(&so); | |||||
| return so; | return so; | ||||
| } | } | ||||
| @@ -63,9 +63,7 @@ TEST_F(MindDataTestVOCOp, TestVOCDetection) { | |||||
| std::string task_mode("train"); | std::string task_mode("train"); | ||||
| std::shared_ptr<VOCOp> my_voc_op; | std::shared_ptr<VOCOp> my_voc_op; | ||||
| VOCOp::Builder builder; | VOCOp::Builder builder; | ||||
| Status rc = builder.SetDir(dataset_path) | |||||
| .SetTask(task_type) | |||||
| .SetMode(task_mode) | |||||
| Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode) | |||||
| .Build(&my_voc_op); | .Build(&my_voc_op); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -116,9 +114,7 @@ TEST_F(MindDataTestVOCOp, TestVOCSegmentation) { | |||||
| std::string task_mode("train"); | std::string task_mode("train"); | ||||
| std::shared_ptr<VOCOp> my_voc_op; | std::shared_ptr<VOCOp> my_voc_op; | ||||
| VOCOp::Builder builder; | VOCOp::Builder builder; | ||||
| Status rc = builder.SetDir(dataset_path) | |||||
| .SetTask(task_type) | |||||
| .SetMode(task_mode) | |||||
| Status rc = builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode) | |||||
| .Build(&my_voc_op); | .Build(&my_voc_op); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
| @@ -173,9 +169,8 @@ TEST_F(MindDataTestVOCOp, TestVOCClassIndex) { | |||||
| class_index["train"] = 5; | class_index["train"] = 5; | ||||
| std::shared_ptr<VOCOp> my_voc_op; | std::shared_ptr<VOCOp> my_voc_op; | ||||
| VOCOp::Builder builder; | VOCOp::Builder builder; | ||||
| Status rc = builder.SetDir(dataset_path) | |||||
| .SetTask(task_type) | |||||
| .SetMode(task_mode) | |||||
| Status rc = | |||||
| builder.SetDir(dataset_path).SetTask(task_type).SetUsage(task_mode) | |||||
| .SetClassIndex(class_index) | .SetClassIndex(class_index) | ||||
| .Build(&my_voc_op); | .Build(&my_voc_op); | ||||
| ASSERT_TRUE(rc.IsOk()); | ASSERT_TRUE(rc.IsOk()); | ||||
@@ -42,8 +42,8 @@ def test_bounding_box_augment_with_rotation_op(plot_vis=False):
     original_seed = config_get_set_seed(0)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     # Ratio is set to 1 to apply rotation on all bounding boxes.
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomRotation(90), 1)
@@ -81,8 +81,8 @@ def test_bounding_box_augment_with_crop_op(plot_vis=False):
     original_seed = config_get_set_seed(0)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     # Ratio is set to 0.9 to apply RandomCrop of size (50, 50) on 90% of the bounding boxes.
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomCrop(50), 0.9)
@@ -120,8 +120,8 @@ def test_bounding_box_augment_valid_ratio_c(plot_vis=False):
     original_seed = config_get_set_seed(1)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 0.9)
@@ -188,8 +188,8 @@ def test_bounding_box_augment_valid_edge_c(plot_vis=False):
     original_seed = config_get_set_seed(1)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1), 1)
@@ -232,7 +232,7 @@ def test_bounding_box_augment_invalid_ratio_c():
     """
     logger.info("test_bounding_box_augment_invalid_ratio_c")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         # ratio range is from 0 - 1
@@ -256,13 +256,13 @@ def test_bounding_box_augment_invalid_bounds_c():
     test_op = c_vision.BoundingBoxAugment(c_vision.RandomHorizontalFlip(1),
                                           1)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
@@ -20,7 +20,7 @@ DATA_DIR = "../data/dataset/testCelebAData/"
 def test_celeba_dataset_label():
-    data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
+    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
     expect_labels = [
         [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
          0, 0, 1],
@@ -85,11 +85,13 @@ def test_celeba_dataset_distribute():
         count = count + 1
     assert count == 1
 def test_celeba_get_dataset_size():
-    data = ds.CelebADataset(DATA_DIR, decode=True, shuffle=False)
+    data = ds.CelebADataset(DATA_DIR, shuffle=False, decode=True)
     size = data.get_dataset_size()
     assert size == 2
 if __name__ == '__main__':
     test_celeba_dataset_label()
     test_celeba_dataset_op()
@@ -392,6 +392,59 @@ def test_cifar100_visualize(plot=False):
         visualize_dataset(image_list, label_list)
+
+
+def test_cifar_usage():
+    """
+    Test the usage flag of Cifar10Dataset and Cifar100Dataset
+    """
+    logger.info("Test Cifar10Dataset and Cifar100Dataset usage flag")
+
+    # flag, if True, test cifar10 else test cifar100
+    def test_config(usage, flag=True, cifar_path=None):
+        if cifar_path is None:
+            cifar_path = DATA_DIR_10 if flag else DATA_DIR_100
+        try:
+            data = ds.Cifar10Dataset(cifar_path, usage=usage) if flag else ds.Cifar100Dataset(cifar_path, usage=usage)
+            num_rows = 0
+            for _ in data.create_dict_iterator():
+                num_rows += 1
+        except (ValueError, TypeError, RuntimeError) as e:
+            return str(e)
+        return num_rows
+
+    # test the usage of CIFAR10
+    assert test_config("train") == 10000
+    assert test_config("all") == 10000
+    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
+    assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
+    assert "no valid data matching the dataset API Cifar10Dataset" in test_config("test")
+
+    # test the usage of CIFAR100
+    assert test_config("test", False) == 10000
+    assert test_config("all", False) == 10000
+    assert "no valid data matching the dataset API Cifar100Dataset" in test_config("train", False)
+    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid", False)
+
+    # change this directory to the folder that contains all cifar10 files
+    all_cifar10 = None
+    if all_cifar10 is not None:
+        assert test_config("train", True, all_cifar10) == 50000
+        assert test_config("test", True, all_cifar10) == 10000
+        assert test_config("all", True, all_cifar10) == 60000
+        assert ds.Cifar10Dataset(all_cifar10, usage="train").get_dataset_size() == 50000
+        assert ds.Cifar10Dataset(all_cifar10, usage="test").get_dataset_size() == 10000
+        assert ds.Cifar10Dataset(all_cifar10, usage="all").get_dataset_size() == 60000
+
+    # change this directory to the folder that contains all cifar100 files
+    all_cifar100 = None
+    if all_cifar100 is not None:
+        assert test_config("train", False, all_cifar100) == 50000
+        assert test_config("test", False, all_cifar100) == 10000
+        assert test_config("all", False, all_cifar100) == 60000
+        assert ds.Cifar100Dataset(all_cifar100, usage="train").get_dataset_size() == 50000
+        assert ds.Cifar100Dataset(all_cifar100, usage="test").get_dataset_size() == 10000
+        assert ds.Cifar100Dataset(all_cifar100, usage="all").get_dataset_size() == 60000
+
 if __name__ == '__main__':
     test_cifar10_content_check()
     test_cifar10_basic()
@@ -405,3 +458,5 @@ if __name__ == '__main__':
     test_cifar100_pk_sampler()
     test_cifar100_exception()
     test_cifar100_visualize(plot=False)
+    test_cifar_usage()
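Taken together, this test exercises the usage argument that the change adds to the CIFAR readers. A minimal stand-alone sketch of the same API, assuming a locally extracted CIFAR-10 binary set at an illustrative path (not part of this change):

    import mindspore.dataset as ds

    cifar10_dir = "./cifar-10-batches-bin"  # illustrative path

    # usage selects the split: "train", "test", or "all" (both splits combined).
    train_data = ds.Cifar10Dataset(cifar10_dir, usage="train", shuffle=False)
    print(train_data.get_dataset_size())

    # Iteration is unchanged by the usage flag.
    for row in train_data.create_dict_iterator():
        _ = row["image"], row["label"]
        break

Cifar100Dataset accepts the same three values, as the assertions in the test above show.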
@@ -58,6 +58,14 @@ def test_mnist_dataset_size():
     ds_total = ds.MnistDataset(MNIST_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000
+    # test get dataset_size with the usage arg
+    test_size = ds.MnistDataset(MNIST_DATA_DIR, usage="test").get_dataset_size()
+    assert test_size == 10000
+    train_size = ds.MnistDataset(MNIST_DATA_DIR, usage="train").get_dataset_size()
+    assert train_size == 0
+    all_size = ds.MnistDataset(MNIST_DATA_DIR, usage="all").get_dataset_size()
+    assert all_size == 10000
+
     ds_shard_1_0 = ds.MnistDataset(MNIST_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
@@ -86,6 +94,14 @@ def test_cifar10_dataset_size():
     ds_total = ds.Cifar10Dataset(CIFAR10_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000
+    # test get_dataset_size with usage flag
+    train_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="train").get_dataset_size()
+    assert train_size == 10000
+    test_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="test").get_dataset_size()
+    assert test_size == 0
+    all_size = ds.Cifar10Dataset(CIFAR10_DATA_DIR, usage="all").get_dataset_size()
+    assert all_size == 10000
+
     ds_shard_1_0 = ds.Cifar10Dataset(CIFAR10_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
@@ -103,6 +119,14 @@ def test_cifar100_dataset_size():
     ds_total = ds.Cifar100Dataset(CIFAR100_DATA_DIR)
     assert ds_total.get_dataset_size() == 10000
+    # test get_dataset_size with usage flag
+    train_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="train").get_dataset_size()
+    assert train_size == 0
+    test_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="test").get_dataset_size()
+    assert test_size == 10000
+    all_size = ds.Cifar100Dataset(CIFAR100_DATA_DIR, usage="all").get_dataset_size()
+    assert all_size == 10000
+
     ds_shard_1_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=1, shard_id=0)
     assert ds_shard_1_0.get_dataset_size() == 10000
@@ -111,3 +135,12 @@ def test_cifar100_dataset_size():
     ds_shard_3_0 = ds.Cifar100Dataset(CIFAR100_DATA_DIR, num_shards=3, shard_id=0)
     assert ds_shard_3_0.get_dataset_size() == 3334
+
+
+if __name__ == '__main__':
+    test_imagenet_rawdata_dataset_size()
+    test_imagenet_tf_file_dataset_size()
+    test_mnist_dataset_size()
+    test_manifest_dataset_size()
+    test_cifar10_dataset_size()
+    test_cifar100_dataset_size()
@@ -229,6 +229,41 @@ def test_mnist_visualize(plot=False):
         visualize_dataset(image_list, label_list)
+
+
+def test_mnist_usage():
+    """
+    Validate MnistDataset usage flag
+    """
+    logger.info("Test MnistDataset usage flag")
+
+    def test_config(usage, mnist_path=None):
+        mnist_path = DATA_DIR if mnist_path is None else mnist_path
+        try:
+            data = ds.MnistDataset(mnist_path, usage=usage, shuffle=False)
+            num_rows = 0
+            for _ in data.create_dict_iterator():
+                num_rows += 1
+        except (ValueError, TypeError, RuntimeError) as e:
+            return str(e)
+        return num_rows
+
+    assert test_config("test") == 10000
+    assert test_config("all") == 10000
+    assert " no valid data matching the dataset API MnistDataset" in test_config("train")
+    assert "usage is not within the valid set of ['train', 'test', 'all']" in test_config("invalid")
+    assert "Argument usage with value ['list'] is not of type (<class 'str'>,)" in test_config(["list"])
+
+    # change this directory to the folder that contains all mnist files
+    all_files_path = None
+    # the following tests run on the entire dataset
+    if all_files_path is not None:
+        assert test_config("train", all_files_path) == 60000
+        assert test_config("test", all_files_path) == 10000
+        assert test_config("all", all_files_path) == 70000
+        assert ds.MnistDataset(all_files_path, usage="train").get_dataset_size() == 60000
+        assert ds.MnistDataset(all_files_path, usage="test").get_dataset_size() == 10000
+        assert ds.MnistDataset(all_files_path, usage="all").get_dataset_size() == 70000
+
 if __name__ == '__main__':
     test_mnist_content_check()
     test_mnist_basic()
@@ -236,3 +271,4 @@ if __name__ == '__main__':
     test_mnist_sequential_sampler()
     test_mnist_exception()
     test_mnist_visualize(plot=True)
+    test_mnist_usage()
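The MNIST test follows the same pattern as the CIFAR one; a short hedged sketch of the added usage flag on MnistDataset, assuming an illustrative local path that contains the idx files:

    import mindspore.dataset as ds

    mnist_dir = "./MNIST_Data"  # illustrative path

    # The test data above ships only the "test" split, hence the 10000/0/10000 counts;
    # on a full download, "train" yields 60000 rows and "all" yields 70000.
    test_data = ds.MnistDataset(mnist_dir, usage="test", shuffle=False)
    print(test_data.get_dataset_size())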
@@ -21,7 +21,7 @@ TARGET_SHAPE = [680, 680, 680, 680, 642, 607, 561, 596, 612, 680]
 def test_voc_segmentation():
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
     num = 0
     for item in data1.create_dict_iterator(num_epochs=1):
         assert item["image"].shape[0] == IMAGE_SHAPE[num]
@@ -31,7 +31,7 @@ def test_voc_segmentation():
 def test_voc_detection():
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     num = 0
     count = [0, 0, 0, 0, 0, 0]
     for item in data1.create_dict_iterator(num_epochs=1):
@@ -45,7 +45,7 @@ def test_voc_detection():
 def test_voc_class_index():
     class_index = {'car': 0, 'cat': 1, 'train': 5}
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", class_indexing=class_index, decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", class_indexing=class_index, decode=True)
     class_index1 = data1.get_class_indexing()
     assert (class_index1 == {'car': 0, 'cat': 1, 'train': 5})
     data1 = data1.shuffle(4)
@@ -63,7 +63,7 @@ def test_voc_class_index():
 def test_voc_get_class_indexing():
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
     class_index1 = data1.get_class_indexing()
     assert (class_index1 == {'car': 0, 'cat': 1, 'chair': 2, 'dog': 3, 'person': 4, 'train': 5})
     data1 = data1.shuffle(4)
@@ -81,7 +81,7 @@ def test_voc_get_class_indexing():
 def test_case_0():
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
     resize_op = vision.Resize((224, 224))
@@ -99,7 +99,7 @@ def test_case_0():
 def test_case_1():
-    data1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", decode=True)
     resize_op = vision.Resize((224, 224))
@@ -116,7 +116,7 @@ def test_case_1():
 def test_case_2():
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", decode=True)
     sizes = [0.5, 0.5]
     randomize = False
     dataset1, dataset2 = data1.split(sizes=sizes, randomize=randomize)
@@ -134,7 +134,7 @@ def test_case_2():
 def test_voc_exception():
     try:
-        data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", mode="train", decode=True)
+        data1 = ds.VOCDataset(DATA_DIR, task="InvalidTask", usage="train", decode=True)
         for _ in data1.create_dict_iterator(num_epochs=1):
             pass
         assert False
@@ -142,7 +142,7 @@ def test_voc_exception():
         pass
     try:
-        data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", class_indexing={"cat": 0}, decode=True)
+        data2 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", class_indexing={"cat": 0}, decode=True)
         for _ in data2.create_dict_iterator(num_epochs=1):
             pass
         assert False
@@ -150,7 +150,7 @@ def test_voc_exception():
         pass
     try:
-        data3 = ds.VOCDataset(DATA_DIR, task="Detection", mode="notexist", decode=True)
+        data3 = ds.VOCDataset(DATA_DIR, task="Detection", usage="notexist", decode=True)
         for _ in data3.create_dict_iterator(num_epochs=1):
             pass
         assert False
@@ -158,7 +158,7 @@ def test_voc_exception():
         pass
     try:
-        data4 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnotexist", decode=True)
+        data4 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnotexist", decode=True)
         for _ in data4.create_dict_iterator(num_epochs=1):
             pass
         assert False
@@ -166,7 +166,7 @@ def test_voc_exception():
         pass
     try:
-        data5 = ds.VOCDataset(DATA_DIR, task="Detection", mode="invalidxml", decode=True)
+        data5 = ds.VOCDataset(DATA_DIR, task="Detection", usage="invalidxml", decode=True)
         for _ in data5.create_dict_iterator(num_epochs=1):
             pass
         assert False
@@ -174,7 +174,7 @@ def test_voc_exception():
        pass
    try:
-        data6 = ds.VOCDataset(DATA_DIR, task="Detection", mode="xmlnoobject", decode=True)
+        data6 = ds.VOCDataset(DATA_DIR, task="Detection", usage="xmlnoobject", decode=True)
         for _ in data6.create_dict_iterator(num_epochs=1):
             pass
         assert False
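For VOCDataset the change is a pure keyword rename at every call site, with no behavioural difference; a hedged before/after sketch using an illustrative path:

    import mindspore.dataset as ds

    voc_dir = "../data/dataset/testVOC2012"  # illustrative path

    # Before this change the split was chosen with `mode`:
    # data = ds.VOCDataset(voc_dir, task="Detection", mode="train", decode=True)

    # After this change the same split is chosen with `usage`:
    data = ds.VOCDataset(voc_dir, task="Detection", usage="train", decode=True, shuffle=False)
    print(data.get_class_indexing())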
@@ -35,6 +35,7 @@ def diff_mse(in1, in2):
     mse = (np.square(in1.astype(float) / 255 - in2.astype(float) / 255)).mean()
     return mse * 100
+
 def test_cifar10():
     """
     dataset parameter
@@ -45,7 +46,7 @@ def test_cifar10():
     batch_size = 32
     limit_dataset = 100
     # apply dataset operations
-    data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
+    data1 = ds.Cifar10Dataset(data_dir_10, num_samples=limit_dataset)
     data1 = data1.repeat(num_repeat)
     data1 = data1.batch(batch_size, True)
     num_epoch = 5
@@ -139,6 +140,7 @@ def test_generator_dict_0():
         np.testing.assert_array_equal(item["data"], golden)
         i = i + 1
+
 def test_generator_dict_1():
     """
     test generator dict 1
@@ -158,6 +160,7 @@ def test_generator_dict_1():
         i = i + 1
     assert i == 64
+
 def test_generator_dict_2():
     """
     test generator dict 2
@@ -180,6 +183,7 @@ def test_generator_dict_2():
     assert item1
     # rely on garbage collector to destroy iter1
+
 def test_generator_dict_3():
     """
     test generator dict 3
@@ -226,6 +230,7 @@ def test_generator_dict_4():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)
+
 def test_generator_dict_4_1():
     """
     test generator dict 4_1
@@ -249,6 +254,7 @@ def test_generator_dict_4_1():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)
+
 def test_generator_dict_4_2():
     """
     test generator dict 4_2
@@ -274,6 +280,7 @@ def test_generator_dict_4_2():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)
+
 def test_generator_dict_5():
     """
     test generator dict 5
@@ -305,6 +312,7 @@ def test_generator_dict_5():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)
+
 # Test tuple iterator
 def test_generator_tuple_0():
@@ -323,6 +331,7 @@ def test_generator_tuple_0():
         np.testing.assert_array_equal(item[0], golden)
         i = i + 1
+
 def test_generator_tuple_1():
     """
     test generator tuple 1
@@ -342,6 +351,7 @@ def test_generator_tuple_1():
         i = i + 1
     assert i == 64
+
 def test_generator_tuple_2():
     """
     test generator tuple 2
@@ -364,6 +374,7 @@ def test_generator_tuple_2():
     assert item1
     # rely on garbage collector to destroy iter1
+
 def test_generator_tuple_3():
     """
     test generator tuple 3
@@ -442,6 +453,7 @@ def test_generator_tuple_5():
     err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
     assert err_msg in str(info.value)
+
 # Test with repeat
 def test_generator_tuple_repeat_1():
     """
@@ -536,6 +548,7 @@ def test_generator_tuple_repeat_repeat_2():
         iter1.__next__()
     assert "object has no attribute 'depipeline'" in str(info.value)
+
 def test_generator_tuple_repeat_repeat_3():
     """
     test generator tuple repeat repeat 3
@@ -149,7 +149,7 @@ def test_get_column_name_to_device():
 def test_get_column_name_voc():
-    data = ds.VOCDataset(VOC_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data = ds.VOCDataset(VOC_DIR, task="Segmentation", usage="train", decode=True, shuffle=False)
     assert data.get_col_names() == ["image", "target"]
@@ -22,7 +22,7 @@ DATA_DIR = "../data/dataset/testVOC2012"
 def test_noop_pserver():
     os.environ['MS_ROLE'] = 'MS_PSERVER'
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
     num = 0
     for _ in data1.create_dict_iterator(num_epochs=1):
         num += 1
@@ -32,7 +32,7 @@ def test_noop_pserver():
 def test_noop_sched():
     os.environ['MS_ROLE'] = 'MS_SCHED'
-    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", mode="train", decode=True, shuffle=False)
+    data1 = ds.VOCDataset(DATA_DIR, task="Segmentation", usage="train", shuffle=False, decode=True)
     num = 0
     for _ in data1.create_dict_iterator(num_epochs=1):
         num += 1
@@ -42,8 +42,8 @@ def test_random_resized_crop_with_bbox_op_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
@@ -108,8 +108,8 @@ def test_random_resized_crop_with_bbox_op_edge_c(plot_vis=False):
     logger.info("test_random_resized_crop_with_bbox_op_edge_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
@@ -142,7 +142,7 @@ def test_random_resized_crop_with_bbox_op_invalid_c():
     logger.info("test_random_resized_crop_with_bbox_op_invalid_c")
     # Load dataset, only Augmented Dataset as test will raise ValueError
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         # If input range of scale is not in the order of (min, max), ValueError will be raised.
@@ -168,7 +168,7 @@ def test_random_resized_crop_with_bbox_op_invalid2_c():
     """
     logger.info("test_random_resized_crop_with_bbox_op_invalid2_c")
     # Load dataset # only loading the to AugDataset as test will fail on this
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         # If input range of ratio is not in the order of (min, max), ValueError will be raised.
@@ -195,13 +195,13 @@ def test_random_resized_crop_with_bbox_op_bad_c():
     logger.info("test_random_resized_crop_with_bbox_op_bad_c")
     test_op = c_vision.RandomResizedCropWithBBox((256, 512), (0.5, 0.5), (0.5, 0.5))
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
@@ -39,8 +39,8 @@ def test_random_crop_with_bbox_op_c(plot_vis=False):
     logger.info("test_random_crop_with_bbox_op_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     # define test OP with values to match existing Op UT
     test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
@@ -101,8 +101,8 @@ def test_random_crop_with_bbox_op2_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     # define test OP with values to match existing Op unit - test
     test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], fill_value=(255, 255, 255))
@@ -138,8 +138,8 @@ def test_random_crop_with_bbox_op3_c(plot_vis=False):
     logger.info("test_random_crop_with_bbox_op3_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     # define test OP with values to match existing Op unit - test
     test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
@@ -168,8 +168,8 @@ def test_random_crop_with_bbox_op_edge_c(plot_vis=False):
     logger.info("test_random_crop_with_bbox_op_edge_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     # define test OP with values to match existing Op unit - test
     test_op = c_vision.RandomCropWithBBox(512, [200, 200, 200, 200], padding_mode=mode.Border.EDGE)
@@ -205,7 +205,7 @@ def test_random_crop_with_bbox_op_invalid_c():
     logger.info("test_random_crop_with_bbox_op_invalid_c")
     # Load dataset
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         # define test OP with values to match existing Op unit - test
@@ -231,13 +231,13 @@ def test_random_crop_with_bbox_op_bad_c():
     logger.info("test_random_crop_with_bbox_op_bad_c")
     test_op = c_vision.RandomCropWithBBox([512, 512], [200, 200, 200, 200])
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
@@ -247,7 +247,7 @@ def test_random_crop_with_bbox_op_bad_padding():
     """
     logger.info("test_random_crop_with_bbox_op_invalid_c")
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         test_op = c_vision.RandomCropWithBBox([512, 512], padding=-1)
@@ -37,11 +37,9 @@ def test_random_horizontal_flip_with_bbox_op_c(plot_vis=False):
     logger.info("test_random_horizontal_flip_with_bbox_op_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomHorizontalFlipWithBBox(1)
@@ -102,11 +100,9 @@ def test_random_horizontal_flip_with_bbox_valid_rand_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomHorizontalFlipWithBBox(0.6)
@@ -140,8 +136,8 @@ def test_random_horizontal_flip_with_bbox_valid_edge_c(plot_vis=False):
     """
     logger.info("test_horizontal_flip_with_bbox_valid_edge_c")
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomHorizontalFlipWithBBox(1)
@@ -178,7 +174,7 @@ def test_random_horizontal_flip_with_bbox_invalid_prob_c():
     """
     logger.info("test_random_horizontal_bbox_invalid_prob_c")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         # Note: Valid range of prob should be [0.0, 1.0]
@@ -201,13 +197,13 @@ def test_random_horizontal_flip_with_bbox_invalid_bounds_c():
     test_op = c_vision.RandomHorizontalFlipWithBBox(1)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(dataVoc2, test_op, InvalidBBoxType.WrongShape, "4 features")
@@ -39,11 +39,9 @@ def test_random_resize_with_bbox_op_voc_c(plot_vis=False):
     original_seed = config_get_set_seed(123)
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomResizeWithBBox(100)
@@ -120,11 +118,9 @@ def test_random_resize_with_bbox_op_edge_c(plot_vis=False):
     box has dimensions as the image itself.
     """
     logger.info("test_random_resize_with_bbox_op_edge_c")
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomResizeWithBBox(500)
@@ -197,13 +193,13 @@ def test_random_resize_with_bbox_op_bad_c():
     logger.info("test_random_resize_with_bbox_op_bad_c")
     test_op = c_vision.RandomResizeWithBBox((400, 300))
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
@@ -37,11 +37,9 @@ def test_random_vertical_flip_with_bbox_op_c(plot_vis=False):
     """
     logger.info("test_random_vertical_flip_with_bbox_op_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomVerticalFlipWithBBox(1)
@@ -102,11 +100,9 @@ def test_random_vertical_flip_with_bbox_op_rand_c(plot_vis=False):
     original_num_parallel_workers = config_get_set_num_parallel_workers(1)
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomVerticalFlipWithBBox(0.8)
@@ -139,11 +135,9 @@ def test_random_vertical_flip_with_bbox_op_edge_c(plot_vis=False):
     applied on dynamically generated edge case, expected to pass
     """
     logger.info("test_random_vertical_flip_with_bbox_op_edge_c")
-    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.RandomVerticalFlipWithBBox(1)
@@ -174,8 +168,7 @@ def test_random_vertical_flip_with_bbox_op_invalid_c():
     Test RandomVerticalFlipWithBBox Op on invalid constructor parameters, expected to raise ValueError
     """
     logger.info("test_random_vertical_flip_with_bbox_op_invalid_c")
-    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     try:
         test_op = c_vision.RandomVerticalFlipWithBBox(2)
@@ -201,13 +194,13 @@ def test_random_vertical_flip_with_bbox_op_bad_c():
     logger.info("test_random_vertical_flip_with_bbox_op_bad_c")
     test_op = c_vision.RandomVerticalFlipWithBBox(1)
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR_VOC, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
@@ -39,11 +39,9 @@ def test_resize_with_bbox_op_voc_c(plot_vis=False):
     logger.info("test_resize_with_bbox_op_voc_c")
     # Load dataset
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.ResizeWithBBox(100)
@@ -110,11 +108,9 @@ def test_resize_with_bbox_op_edge_c(plot_vis=False):
     box has dimensions as the image itself.
     """
     logger.info("test_resize_with_bbox_op_edge_c")
-    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc1 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
-    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train",
-                             decode=True, shuffle=False)
+    dataVoc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     test_op = c_vision.ResizeWithBBox(500)
@@ -163,13 +159,13 @@ def test_resize_with_bbox_op_bad_c():
     logger.info("test_resize_with_bbox_op_bad_c")
     test_op = c_vision.ResizeWithBBox((200, 300))
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WidthOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.HeightOverflow, "bounding boxes is out of bounds of the image")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.NegativeXY, "min_x")
-    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", mode="train", decode=True, shuffle=False)
+    data_voc2 = ds.VOCDataset(DATA_DIR, task="Detection", usage="train", shuffle=False, decode=True)
     check_bad_bbox(data_voc2, test_op, InvalidBBoxType.WrongShape, "4 features")
@@ -32,6 +32,7 @@ from mindspore.dataset.vision import Inter
+
 def test_imagefolder(remove_json_files=True):
     """
     Test simulating resnet50 dataset pipeline.
@@ -103,7 +104,7 @@ def test_mnist_dataset(remove_json_files=True):
     data_dir = "../data/dataset/testMnistData"
     ds.config.set_seed(1)
-    data1 = ds.MnistDataset(data_dir, 100)
+    data1 = ds.MnistDataset(data_dir, num_samples=100)
     one_hot_encode = c.OneHot(10)  # num_classes is input argument
     data1 = data1.map(input_columns="label", operations=one_hot_encode)
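The serialization test switches 100 from a positional to a keyword argument, presumably because MnistDataset now takes the new usage parameter ahead of num_samples; passing it by keyword keeps the call unambiguous either way. A small sketch of the safer call style, with an illustrative path:

    import mindspore.dataset as ds

    data_dir = "../data/dataset/testMnistData"  # illustrative path
    # num_samples passed by keyword stays valid even when new positional
    # parameters such as usage are inserted before it.
    data1 = ds.MnistDataset(data_dir, usage="test", num_samples=100)
    count = sum(1 for _ in data1.create_dict_iterator())
    assert count <= 100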