| @@ -110,6 +110,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||
| #ifndef ENABLE_ANDROID | |||
| #include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/omniglot_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/penn_treebank_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/places365_node.h" | |||
| @@ -1675,6 +1676,31 @@ Multi30kDataset::Multi30kDataset(const std::vector<char> &dataset_dir, const std | |||
| ir_node_ = std::static_pointer_cast<Multi30kNode>(ds); | |||
| } | |||
| OmniglotDataset::OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode, | |||
| const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) { | |||
| // Create logical representation of OmniglotDataset. | |||
| auto sampler_obj = sampler ? sampler->Parse() : nullptr; | |||
| auto ds = std::make_shared<OmniglotNode>(CharToString(dataset_dir), background, decode, sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| OmniglotDataset::OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode, | |||
| const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) { | |||
| // Create logical representation of OmniglotDataset. | |||
| auto sampler_obj = sampler ? sampler->Parse() : nullptr; | |||
| auto ds = std::make_shared<OmniglotNode>(CharToString(dataset_dir), background, decode, sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| OmniglotDataset::OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode, | |||
| const std::reference_wrapper<Sampler> &sampler, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| // Create logical representation of OmniglotDataset. | |||
| auto sampler_obj = sampler.get().Parse(); | |||
| auto ds = std::make_shared<OmniglotNode>(CharToString(dataset_dir), background, decode, sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| PennTreebankDataset::PennTreebankDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| @@ -75,6 +75,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/minddata_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/multi30k_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/omniglot_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/photo_tour_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/places365_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/qmnist_node.h" | |||
| @@ -544,6 +545,18 @@ PYBIND_REGISTER(Multi30kNode, 2, ([](const py::module *m) { | |||
| })); | |||
| })); | |||
// Python binding for OmniglotNode: exposes a constructor taking
// (dataset_dir, background, decode, sampler). The cache argument is not
// exposed here (passed as nullptr), and parameters are validated eagerly
// so invalid inputs raise on the Python side at construction time.
PYBIND_REGISTER(OmniglotNode, 2, ([](const py::module *m) {
                  (void)py::class_<OmniglotNode, DatasetNode, std::shared_ptr<OmniglotNode>>(
                    *m, "OmniglotNode", "to create an OmniglotNode")
                    .def(py::init([](const std::string &dataset_dir, bool background, bool decode,
                                     const py::handle &sampler) {
                      // Convert the Python sampler handle into a C++ SamplerObj.
                      auto omniglot =
                        std::make_shared<OmniglotNode>(dataset_dir, background, decode, toSamplerObj(sampler), nullptr);
                      THROW_IF_ERROR(omniglot->ValidateParams());
                      return omniglot;
                    }));
                }));
| PYBIND_REGISTER(PennTreebankNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<PennTreebankNode, DatasetNode, std::shared_ptr<PennTreebankNode>>( | |||
| *m, "PennTreebankNode", "to create a PennTreebankNode") | |||
| @@ -37,6 +37,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| mnist_op.cc | |||
| multi30k_op.cc | |||
| nonmappable_leaf_op.cc | |||
| omniglot_op.cc | |||
| penn_treebank_op.cc | |||
| photo_tour_op.cc | |||
| places365_op.cc | |||
| @@ -0,0 +1,132 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/datasetops/source/omniglot_op.h" | |||
| #include <fstream> | |||
| #include <unordered_set> | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/core/tensor_shape.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| OmniglotOp::OmniglotOp(int32_t num_wkrs, const std::string &file_dir, int32_t queue_size, bool background, | |||
| bool do_decode, std::unique_ptr<DataSchema> data_schema, | |||
| const std::shared_ptr<SamplerRT> &sampler) | |||
| : ImageFolderOp(num_wkrs, file_dir, queue_size, false, do_decode, {}, {}, std::move(data_schema), | |||
| std::move(sampler)) { | |||
| Path dir(file_dir); | |||
| if (background) { | |||
| folder_path_ = (dir / "images_background").ToString(); | |||
| } else { | |||
| folder_path_ = (dir / "images_evaluation").ToString(); | |||
| } | |||
| } | |||
| void OmniglotOp::Print(std::ostream &out, bool show_all) const { | |||
| if (!show_all) { | |||
| // Call the super class for displaying any common 1-liner info. | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal 1-liner info for this op. | |||
| out << "\n"; | |||
| } else { | |||
| // Call the super class for displaying any common detailed info. | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff. | |||
| out << "\nNumber of rows: " << num_rows_ << "\nOmniglot directory: " << folder_path_ | |||
| << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n"; | |||
| } | |||
| } | |||
// This helper function walks all folders below *dir and sends each folder name to folder_name_queue_.
Status OmniglotOp::RecursiveWalkFolder(Path *dir) {
  RETURN_UNEXPECTED_IF_NULL(dir);
  // folder_paths is unused here (std_queue == false routes the names into the
  // base-class queue instead); it only exists to satisfy WalkDir's signature.
  std::queue<std::string> folder_paths;
  return WalkDir(dir, &folder_paths, folder_name_queue_.get(), dirname_offset_, false);
}
| Status OmniglotOp::WalkDir(Path *dir, std::queue<std::string> *folder_paths, Queue<std::string> *folder_name_queue, | |||
| uint64_t dirname_offset, bool std_queue) { | |||
| RETURN_UNEXPECTED_IF_NULL(dir); | |||
| RETURN_UNEXPECTED_IF_NULL(folder_paths); | |||
| RETURN_UNEXPECTED_IF_NULL(folder_name_queue); | |||
| std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(dir); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(dir_itr != nullptr, "Invalid path, failed to open omniglot image dir: " + | |||
| (*dir).ToString() + ", permission denied."); | |||
| while (dir_itr->HasNext()) { | |||
| Path subdir = dir_itr->Next(); | |||
| if (subdir.IsDirectory()) { | |||
| std::shared_ptr<Path::DirIterator> dir_itr_sec = Path::DirIterator::OpenDirectory(&subdir); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(dir_itr_sec != nullptr, "Invalid path, failed to open omniglot image dir: " + | |||
| subdir.ToString() + ", permission denied."); | |||
| while (dir_itr_sec->HasNext()) { | |||
| Path subsubdir = dir_itr_sec->Next(); | |||
| if (subsubdir.IsDirectory()) { | |||
| if (std_queue) { | |||
| folder_paths->push(subsubdir.ToString()); | |||
| } else { | |||
| RETURN_IF_NOT_OK(folder_name_queue->EmplaceBack(subsubdir.ToString().substr(dirname_offset))); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status OmniglotOp::CountRowsAndClasses(const std::string &path, int64_t *num_rows, int64_t *num_classes) { | |||
| Path dir(path); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(dir.Exists() && dir.IsDirectory(), | |||
| "Invalid parameter, input path is invalid or not set, path: " + path); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_classes != nullptr || num_rows != nullptr, | |||
| "[Internal ERROR] num_class and num_rows are null."); | |||
| int64_t row_cnt = 0; | |||
| std::queue<std::string> folder_paths; | |||
| Queue<std::string> tmp_queue = Queue<std::string>(1); | |||
| RETURN_IF_NOT_OK(WalkDir(&dir, &folder_paths, &tmp_queue, 0, true)); | |||
| if (num_classes != nullptr) { | |||
| *num_classes = folder_paths.size(); | |||
| } | |||
| RETURN_OK_IF_TRUE(num_rows == nullptr); | |||
| while (!folder_paths.empty()) { | |||
| Path subdir(folder_paths.front()); | |||
| auto dir_itr = Path::DirIterator::OpenDirectory(&subdir); | |||
| while (dir_itr->HasNext()) { | |||
| ++row_cnt; | |||
| } | |||
| folder_paths.pop(); | |||
| } | |||
| (*num_rows) = row_cnt; | |||
| return Status::OK(); | |||
| } | |||
| // Get number of classes | |||
| Status OmniglotOp::GetNumClasses(int64_t *num_classes) { | |||
| RETURN_UNEXPECTED_IF_NULL(num_classes); | |||
| if (num_classes_ > 0) { | |||
| *num_classes = num_classes_; | |||
| return Status::OK(); | |||
| } | |||
| RETURN_IF_NOT_OK(CountRowsAndClasses(folder_path_, nullptr, num_classes)); | |||
| num_classes_ = *num_classes; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,98 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_OMNIGLOT_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_OMNIGLOT_OP_H_ | |||
| #include <memory> | |||
| #include <queue> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/datasetops/source/image_folder_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Forward declares. | |||
| template <typename T> | |||
| class Queue; | |||
| using ImageLabelPair = std::shared_ptr<std::pair<std::string, int32_t>>; | |||
| using FolderImagesPair = std::shared_ptr<std::pair<std::string, std::queue<ImageLabelPair>>>; | |||
class OmniglotOp : public ImageFolderOp {
 public:
  /// Constructor.
  /// @param num_wkrs - Num of workers reading images in parallel.
  /// @param file_dir - Root directory of the Omniglot dataset (the split subfolder is chosen via `background`).
  /// @param queue_size - Connector queue size.
  /// @param background - Use the background dataset or the evaluation dataset.
  /// @param do_decode - Decode the images after reading.
  /// @param data_schema - Schema of the Omniglot dataset.
  /// @param sampler - Sampler tells OmniglotOp what to read.
  OmniglotOp(int32_t num_wkrs, const std::string &file_dir, int32_t queue_size, bool background, bool do_decode,
             std::unique_ptr<DataSchema> data_schema, const std::shared_ptr<SamplerRT> &sampler);

  /// Destructor.
  ~OmniglotOp() = default;

  /// A print method typically used for debugging.
  /// @param out - The output stream to write output to.
  /// @param show_all - A bool to control if you want to show all info or just a summary.
  void Print(std::ostream &out, bool show_all) const override;

  /// Common helper to walk two directory levels below one directory.
  /// @param dir - The directory path.
  /// @param folder_paths - The std::queue filled when std_queue is true (used by CountRowsAndClasses).
  /// @param folder_name_queue - The Queue from the base class, filled when std_queue is false.
  /// @param dirname_offset - Number of leading path characters stripped before queueing
  ///     (used by RecursiveWalkFolder).
  /// @param std_queue - If true fill folder_paths, otherwise fill folder_name_queue.
  /// @return Status - The error code returned.
  static Status WalkDir(Path *dir, std::queue<std::string> *folder_paths, Queue<std::string> *folder_name_queue,
                        uint64_t dirname_offset, bool std_queue);

  /// This function is a hack! It is to return the num_class and num_rows. The result
  /// returned by this function may not be consistent with what omniglot_op is going to return,
  /// use this at your own risk!
  /// @param path - The folder path.
  /// @param num_rows - Pointer to the number of rows (may be nullptr when not wanted).
  /// @param num_classes - Pointer to the number of classes (may be nullptr when not wanted).
  /// @return Status - The error code returned.
  static Status CountRowsAndClasses(const std::string &path, int64_t *num_rows, int64_t *num_classes);

  /// Op name getter.
  /// @return std::string - Name of the current Op.
  std::string Name() const override { return "OmniglotOp"; }

  /// DatasetName name getter.
  /// @param upper - A bool to control if you want to return uppercase or lowercase Op name.
  /// @return std::string - DatasetName of the current Op.
  std::string DatasetName(bool upper = false) const { return upper ? "Omniglot" : "omniglot"; }

  /// Base-class override for GetNumClasses.
  /// @param num_classes - The number of classes.
  /// @return Status - The error code returned.
  Status GetNumClasses(int64_t *num_classes) override;

 private:
  /// Walk the dataset folder and enqueue every second-level folder found.
  /// @param dir - The folder path.
  /// @return Status - The error code returned.
  Status RecursiveWalkFolder(Path *dir) override;
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_OMNIGLOT_OP_H_ | |||
| @@ -112,6 +112,7 @@ constexpr char kManifestNode[] = "ManifestDataset"; | |||
| constexpr char kMindDataNode[] = "MindDataDataset"; | |||
| constexpr char kMnistNode[] = "MnistDataset"; | |||
| constexpr char kMulti30kNode[] = "Multi30kDataset"; | |||
| constexpr char kOmniglotNode[] = "OmniglotDataset"; | |||
| constexpr char kPennTreebankNode[] = "PennTreebankDataset"; | |||
| constexpr char kPhotoTourNode[] = "PhotoTourDataset"; | |||
| constexpr char kPlaces365Node[] = "Places365Dataset"; | |||
| @@ -38,6 +38,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | |||
| minddata_node.cc | |||
| mnist_node.cc | |||
| multi30k_node.cc | |||
| omniglot_node.cc | |||
| penn_treebank_node.cc | |||
| photo_tour_node.cc | |||
| places365_node.cc | |||
| @@ -0,0 +1,130 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/omniglot_node.h" | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/omniglot_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
// Constructor: record the dataset location and options; argument checking is
// deferred to ValidateParams().
OmniglotNode::OmniglotNode(const std::string &dataset_dir, bool background, bool decode,
                           const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache)
    : MappableSourceNode(std::move(cache)),
      dataset_dir_(dataset_dir),
      background_(background),
      decode_(decode),
      sampler_(sampler) {}
| std::shared_ptr<DatasetNode> OmniglotNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); | |||
| auto node = std::make_shared<OmniglotNode>(dataset_dir_, background_, decode_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void OmniglotNode::Print(std::ostream &out) const { | |||
| out << (Name() + "(path: " + dataset_dir_ + ", background: " + (background_ ? "true" : "false") + | |||
| ", decode: " + (decode_ ? "true" : "false") + ")"); | |||
| } | |||
// Validate constructor arguments: base-node checks, then the dataset
// directory must exist/be readable and the sampler must be usable.
Status OmniglotNode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("OmniglotDataset", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateDatasetSampler("OmniglotDataset", sampler_));
  return Status::OK();
}
// Create the runtime OmniglotOp for this IR node and append it to node_ops.
Status OmniglotNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  // Do internal Schema generation: column "image" is a flexible-rank uint8
  // tensor, column "label" a uint32 scalar. The schema exists in OmniglotOp
  // but is not externalized (in the Python API).
  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  // Build the runtime sampler from the IR-level sampler object.
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  auto op = std::make_shared<OmniglotOp>(num_workers_, dataset_dir_, connector_que_size_, background_, decode_,
                                         std::move(schema), std::move(sampler_rt));
  // Propagate repeat bookkeeping so the op tracks epochs correctly.
  op->SetTotalRepeats(GetTotalRepeats());
  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}
// Get the shard id of node.
// Delegates to the sampler, which owns the sharding configuration
// (sampler_ is non-null after ValidateParams has passed).
Status OmniglotNode::GetShardId(int32_t *shard_id) {
  RETURN_UNEXPECTED_IF_NULL(shard_id);
  *shard_id = sampler_->ShardId();
  return Status::OK();
}
// Get Dataset size: number of rows this node will produce after sampling.
Status OmniglotNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                    int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  // Return the cached value when the size was already computed.
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t sample_size, num_rows;
  // Point at the same split directory the runtime op would read from.
  Path dataset_path(dataset_dir_);
  if (background_) {
    dataset_path = dataset_path / "images_background";
  } else {
    dataset_path = dataset_path / "images_evaluation";
  }
  std::string path_str = dataset_path.ToString();
  RETURN_IF_NOT_OK(OmniglotOp::CountRowsAndClasses(path_str, &num_rows, nullptr));
  // Ask the sampler how many of those rows it would actually emit.
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  sample_size = sampler_rt->CalculateNumSamples(num_rows);
  if (sample_size == -1) {
    // Sampler cannot compute the count statically; fall back to a dry run.
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
  }
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;  // cache for subsequent calls
  return Status::OK();
}
| Status OmniglotNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_UNEXPECTED_IF_NULL(out_json); | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["background"] = background_; | |||
| args["decode"] = decode_; | |||
| args["sampler"] = sampler_args; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,103 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_OMNIGLOT_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_OMNIGLOT_NODE_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
/// \class OmniglotNode
/// \brief A Dataset derived class to represent the Omniglot dataset.
class OmniglotNode : public MappableSourceNode {
 public:
  /// \brief Constructor.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] background A flag to use the background dataset or the evaluation dataset.
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use (may be nullptr).
  OmniglotNode(const std::string &dataset_dir, bool background, bool decode, const std::shared_ptr<SamplerObj> &sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor.
  ~OmniglotNode() = default;

  /// \brief Node name getter.
  /// \return Name of the current node.
  std::string Name() const override { return kOmniglotNode; }

  /// \brief Print the description.
  /// \param out The output stream to write output to.
  void Print(std::ostream &out) const override;

  /// \brief Copy the node to a new object.
  /// \return A shared pointer to the new copy.
  std::shared_ptr<DatasetNode> Copy() override;

  /// \brief A base-class override function to create the required runtime dataset op objects for this class.
  /// \param node_ops A vector containing shared pointers to the Dataset Ops that this object will create.
  /// \return Status Status::OK() if build successfully.
  Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;

  /// \brief Parameters validation.
  /// \return Status Status::OK() if all the parameters are valid.
  Status ValidateParams() override;

  /// \brief Get the shard id of node.
  /// \param[out] shard_id The shard ID within num_shards.
  /// \return Status Status::OK() if get shard id successfully.
  Status GetShardId(int32_t *shard_id) override;

  /// \brief Base-class override for GetDatasetSize.
  /// \param[in] size_getter Shared pointer to DatasetSizeGetter.
  /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
  ///     dataset size at the expense of accuracy.
  /// \param[out] dataset_size The size of the dataset.
  /// \return Status of the function.
  Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                        int64_t *dataset_size) override;

  /// \brief Getter functions.
  const std::string &DatasetDir() const { return dataset_dir_; }
  bool Background() const { return background_; }
  bool Decode() const { return decode_; }

  /// \brief Get the arguments of node.
  /// \param[out] out_json JSON string of all attributes.
  /// \return Status of the function.
  Status to_json(nlohmann::json *out_json) override;

  /// \brief Sampler getter.
  /// \return SamplerObj of the current node.
  std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }

  /// \brief Sampler setter.
  /// \param[in] sampler Specify sampler.
  void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }

 private:
  std::string dataset_dir_;  // root directory of the dataset
  bool background_;          // true: "images_background" split; false: "images_evaluation" split
  bool decode_;              // decode images after reading
  std::shared_ptr<SamplerObj> sampler_;
};
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_OMNIGLOT_NODE_H_ | |||
| @@ -4318,6 +4318,99 @@ inline std::shared_ptr<Multi30kDataset> MS_API Multi30k(const std::string &datas | |||
| shard_id, cache); | |||
| } | |||
/// \class OmniglotDataset
/// \brief A source dataset for reading and parsing the Omniglot dataset.
class MS_API OmniglotDataset : public Dataset {
 public:
  /// \brief Constructor of OmniglotDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] background A flag to use the background dataset or the evaluation dataset.
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode,
                  const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of OmniglotDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] background A flag to use the background dataset or the evaluation dataset.
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode, const Sampler *sampler,
                  const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of OmniglotDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] background A flag to use the background dataset or the evaluation dataset.
  /// \param[in] decode Decode the images after reading.
  /// \param[in] sampler Sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  OmniglotDataset(const std::vector<char> &dataset_dir, bool background, bool decode,
                  const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor of OmniglotDataset.
  ~OmniglotDataset() = default;
};
/// \brief Function to create an OmniglotDataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] background A flag to use the background dataset or the evaluation dataset (Default=true).
/// \param[in] decode Decode the images after reading (Default=false).
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
///     given, a `RandomSampler` will be used to randomly iterate the entire dataset (Default=RandomSampler()).
/// \param[in] cache Tensor cache to use (Default=nullptr, which means no cache is used).
/// \return Shared pointer to the current OmniglotDataset.
/// \par Example
/// \code
///     /* Define dataset path and MindData object */
///     std::string folder_path = "/path/to/omniglot_dataset_directory";
///     std::shared_ptr<Dataset> ds = Omniglot(folder_path, true, false, std::make_shared<RandomSampler>(false, 5));
///
///     /* Create iterator to read dataset */
///     std::shared_ptr<Iterator> iter = ds->CreateIterator();
///     std::unordered_map<std::string, mindspore::MSTensor> row;
///     iter->GetNextRow(&row);
///
///     /* Note: In Omniglot dataset, each dictionary has keys "image" and "label" */
///     auto image = row["image"];
/// \endcode
inline std::shared_ptr<OmniglotDataset> MS_API
Omniglot(const std::string &dataset_dir, bool background = true, bool decode = false,
         const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
         const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<OmniglotDataset>(StringToChar(dataset_dir), background, decode, sampler, cache);
}
/// \brief Function to create an OmniglotDataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] background A flag to use the background dataset or the evaluation dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (Default=nullptr, which means no cache is used).
/// \return Shared pointer to the current OmniglotDataset.
inline std::shared_ptr<OmniglotDataset> MS_API Omniglot(const std::string &dataset_dir, bool background, bool decode,
                                                        const Sampler *sampler,
                                                        const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<OmniglotDataset>(StringToChar(dataset_dir), background, decode, sampler, cache);
}
/// \brief Function to create an OmniglotDataset.
/// \note The generated dataset has two columns ["image", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] background A flag to use the background dataset or the evaluation dataset.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (Default=nullptr, which means no cache is used).
/// \return Shared pointer to the current OmniglotDataset.
inline std::shared_ptr<OmniglotDataset> MS_API Omniglot(const std::string &dataset_dir, bool background, bool decode,
                                                        const std::reference_wrapper<Sampler> &sampler,
                                                        const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<OmniglotDataset>(StringToChar(dataset_dir), background, decode, sampler, cache);
}
| /// \class PennTreebankDataset | |||
| /// \brief A source dataset for reading and parsing PennTreebank dataset. | |||
| class MS_API PennTreebankDataset : public Dataset { | |||
| @@ -59,6 +59,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> { | |||
| friend class ManifestDataset; | |||
| friend class MindDataDataset; | |||
| friend class MnistDataset; | |||
| friend class OmniglotDataset; | |||
| friend class PhotoTourDataset; | |||
| friend class Places365Dataset; | |||
| friend class QMnistDataset; | |||
| @@ -56,6 +56,7 @@ __all__ = ["Caltech101Dataset", # Vision | |||
| "LSUNDataset", # Vision | |||
| "ManifestDataset", # Vision | |||
| "MnistDataset", # Vision | |||
| "OmniglotDataset", # Vision | |||
| "PhotoTourDataset", # Vision | |||
| "Places365Dataset", # Vision | |||
| "QMnistDataset", # Vision | |||
from .validators import check_imagefolderdataset, check_kittidataset, \
    check_usps_dataset, check_div2k_dataset, check_random_dataset, \
    check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \
    check_photo_tour_dataset, check_svhn_dataset, check_stl10_dataset, check_semeion_dataset, \
    check_caltech101_dataset, check_caltech256_dataset, check_wider_face_dataset, check_lfw_dataset, \
    check_lsun_dataset, check_omniglotdataset
from ..core.validator_helpers import replace_none
| @@ -3090,6 +3091,135 @@ class MnistDataset(MappableDataset, VisionBaseDataset): | |||
| return cde.MnistNode(self.dataset_dir, self.usage, self.sampler) | |||
class OmniglotDataset(MappableDataset, VisionBaseDataset):
    """
    A source dataset that reads and parses the Omniglot dataset.

    The generated dataset has two columns :py:obj:`[image, label]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is a scalar of the int32 type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        background (bool, optional): Use the background dataset or the evaluation dataset
            (default=None, will use the background dataset).
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, all images).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, set in the config).
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
            (default=None, expected order behavior shown in the table).
        decode (bool, optional): Decode the images after reading (default=False).
        sampler (Sampler, optional): Object used to choose samples from the
            dataset (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None). When this argument is specified, `num_samples` reflects
            the max sample number per shard.
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument can only be specified when num_shards is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If `sampler` and `shuffle` are specified at the same time.
        RuntimeError: If `sampler` and `sharding` are specified at the same time.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `shard_id` is invalid (< 0 or >= `num_shards`).

    Note:
        - This dataset can take in a sampler. `sampler` and `shuffle` are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter `sampler`
         - Parameter `shuffle`
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> omniglot_dataset_dir = "/path/to/omniglot_dataset_directory"
        >>> dataset = ds.OmniglotDataset(dataset_dir=omniglot_dataset_dir,
        ...                              num_parallel_workers=8)

    About Omniglot dataset:

    The Omniglot dataset is designed for developing more human-like learning algorithms. Omniglot is a large dataset
    of hand-written characters with 1623 characters and 20 examples for each character. These characters are collected
    based upon 50 alphabets from different countries. It contains both images and strokes data. Stroke data are
    coordinates with time in milliseconds.

    You can unzip the original Omniglot dataset files into this directory structure and read by MindSpore's API.

    .. code-block::

        .
        └── omniglot_dataset_directory
             ├── images_background/
             │    ├── character_class1/
             │    │    ├── 01.jpg
             │    │    ├── 02.jpg
             │    ├── character_class2/
             │    │    ├── 01.jpg
             │    │    ├── 02.jpg
             │    ├── ...
             ├── images_evaluation/
             │    ├── character_class1/
             │    │    ├── 01.jpg
             │    │    ├── 02.jpg
             │    ├── character_class2/
             │    │    ├── 01.jpg
             │    │    ├── 02.jpg
             │    ├── ...

    Citation:

    .. code-block::

        @article{lake2015human,
          title={Human-level concept learning through probabilistic program induction},
          author={Lake, Brenden M and Salakhutdinov, Ruslan and Tenenbaum, Joshua B},
          journal={Science},
          volume={350},
          number={6266},
          pages={1332--1338},
          year={2015},
          publisher={American Association for the Advancement of Science}
        }
    """

    @check_omniglotdataset
    def __init__(self, dataset_dir, background=None, num_samples=None, num_parallel_workers=None, shuffle=None,
                 decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
        self.dataset_dir = dataset_dir
        # None defaults to the background split / no decoding respectively.
        self.background = replace_none(background, True)
        self.decode = replace_none(decode, False)

    def parse(self, children=None):
        # Build the C++ IR node backing this dataset.
        return cde.OmniglotNode(self.dataset_dir, self.background, self.decode, self.sampler)
| class PhotoTourDataset(MappableDataset, VisionBaseDataset): | |||
| """ | |||
| A source dataset that reads and parses the PhotoTour dataset. | |||
| @@ -358,6 +358,30 @@ def check_mnist_cifar_dataset(method): | |||
| return new_method | |||
def check_omniglotdataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(OmniglotDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # The dataset directory must exist before anything else is validated.
        check_dir(param_dict.get('dataset_dir'))

        # Optional integer / boolean parameters.
        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'background', 'decode']
        validate_dataset_param_value(int_params, param_dict, int)
        validate_dataset_param_value(bool_params, param_dict, bool)

        # sampler/shuffle/sharding combinations are mutually constrained.
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
| def check_photo_tour_dataset(method): | |||
| """A wrapper that wraps a parameter checker around the original Dataset(PhotoTourDataset).""" | |||
| @@ -46,6 +46,7 @@ SET(DE_UT_SRCS | |||
| c_api_dataset_manifest_test.cc | |||
| c_api_dataset_minddata_test.cc | |||
| c_api_dataset_multi30k_test.cc | |||
| c_api_dataset_omniglot_test.cc | |||
| c_api_dataset_ops_test.cc | |||
| c_api_dataset_penn_treebank_test.cc | |||
| c_api_dataset_photo_tour_test.cc | |||
| @@ -0,0 +1,297 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/include/dataset/datasets.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::dataset::DataType; | |||
| using mindspore::dataset::Tensor; | |||
| using mindspore::dataset::TensorShape; | |||
// Test fixture for the Omniglot pipeline tests.
// NOTE(review): datasets_root_path_ used by the tests below is presumably provided by
// UT::DatasetOpTesting (common/common.h) — confirm against the base class.
class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};
/// Feature: OmniglotDataset
/// Description: iterate the background split with RandomSampler(false, 5) and count rows
/// Expectation: all 4 rows of the test data are produced with "image" and "label" columns
TEST_F(MindDataTestPipeline, TestOmniglotBackgroundDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOmniglotBackgroundDataset.";

  // Create a Omniglot Dataset (sampler asks for 5, but the test data only has 4 images).
  std::string folder_path = datasets_root_path_ + "/testOmniglot";
  std::shared_ptr<Dataset> ds = Omniglot(folder_path, true, false, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 4);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: OmniglotDataset
/// Description: iterate the evaluation split (background=false) with RandomSampler(false, 5)
/// Expectation: all 4 rows of the test data are produced with "image" and "label" columns
TEST_F(MindDataTestPipeline, TestOmniglotEvaluationDataset) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOmniglotEvaluationDataset.";

  // Create a Omniglot Dataset (sampler asks for 5, but the test data only has 4 images).
  std::string folder_path = datasets_root_path_ + "/testOmniglot";
  std::shared_ptr<Dataset> ds = Omniglot(folder_path, false, false, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  EXPECT_EQ(i, 4);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: OmniglotDataset
/// Description: build two background-split pipelines, apply Repeat and Project, then Concat them
/// Expectation: the concatenated pipeline yields 4 + 4 = 8 rows
TEST_F(MindDataTestPipeline, TestOmniglotBackgroundDatasetWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOmniglotBackgroundDatasetWithPipeline.";

  // Create two Omniglot Dataset.
  std::string folder_path = datasets_root_path_ + "/testOmniglot";
  std::shared_ptr<Dataset> ds1 = Omniglot(folder_path, true, false, std::make_shared<RandomSampler>(false, 5));
  std::shared_ptr<Dataset> ds2 = Omniglot(folder_path, true, false, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds1, nullptr);
  EXPECT_NE(ds2, nullptr);

  // Create two Repeat operation on ds (repeat count of 1 leaves the row count unchanged).
  int32_t repeat_num = 1;
  ds1 = ds1->Repeat(repeat_num);
  EXPECT_NE(ds1, nullptr);
  repeat_num = 1;
  ds2 = ds2->Repeat(repeat_num);
  EXPECT_NE(ds2, nullptr);

  // Create two Project operation on ds.
  std::vector<std::string> column_project = {"image", "label"};
  ds1 = ds1->Project(column_project);
  EXPECT_NE(ds1, nullptr);
  ds2 = ds2->Project(column_project);
  EXPECT_NE(ds2, nullptr);

  // Create a Concat operation on the ds.
  ds1 = ds1->Concat({ds2});
  EXPECT_NE(ds1, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_NE(row.find("image"), row.end());
  EXPECT_NE(row.find("label"), row.end());

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  // 4 rows from each branch of the Concat.
  EXPECT_EQ(i, 8);

  // Manually terminate the pipeline.
  iter->Stop();
}
| /// Feature: OmniglotDataset | |||
| /// Description: test Omniglot | |||
| /// Expectation: get correct Omniglot dataset | |||
| TEST_F(MindDataTestPipeline, TestOmniglotBackgroundGetDatasetSize) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetOmniglotBackgroundSize."; | |||
| // Create a Omniglot Dataset. | |||
| std::string folder_path = datasets_root_path_ + "/testOmniglot"; | |||
| std::shared_ptr<Dataset> ds = Omniglot(folder_path, true, false); | |||
| EXPECT_NE(ds, nullptr); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 4); | |||
| } | |||
| /// Feature: OmniglotDataset | |||
| /// Description: test Omniglot | |||
| /// Expectation: get correct Omniglot dataset | |||
| TEST_F(MindDataTestPipeline, TestOmniglotEvaluationGetDatasetSize) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetOmniglotEvaluationDatasetSize."; | |||
| // Create a Omniglot Dataset. | |||
| std::string folder_path = datasets_root_path_ + "/testOmniglot"; | |||
| std::shared_ptr<Dataset> ds = Omniglot(folder_path, false, false); | |||
| EXPECT_NE(ds, nullptr); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 4); | |||
| } | |||
/// Feature: OmniglotDataset
/// Description: exercise all dataset getters on the background split (decode=true);
///   several getters are deliberately called more than once and re-checked
/// Expectation: size 4, columns ["image", "label"], types [uint8, int32], scalar label
///   shape, 2 classes, batch size 1, repeat count 1 — stable across repeated calls
TEST_F(MindDataTestPipeline, TestOmniglotBackgroundDatasetGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOmniglotBackgroundDatasetGetters.";

  // Create a Omniglot Dataset.
  std::string folder_path = datasets_root_path_ + "/testOmniglot";
  std::shared_ptr<Dataset> ds = Omniglot(folder_path, true, true);
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 4);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "label"};
  int64_t num_classes = ds->GetNumClasses();
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  // Label column is asserted to be int32 here.
  EXPECT_EQ(types[1].ToString(), "int32");
  EXPECT_EQ(shapes.size(), 2);
  // "<>" is the string form of a scalar (rank-0) shape.
  EXPECT_EQ(shapes[1].ToString(), "<>");
  EXPECT_EQ(num_classes, 2);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // Repeat the getter calls and check the values have not changed.
  EXPECT_EQ(ds->GetDatasetSize(), 4);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetNumClasses(), 2);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ds->GetDatasetSize(), 4);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetNumClasses(), 2);
  EXPECT_EQ(ds->GetDatasetSize(), 4);
}
/// Feature: OmniglotDataset
/// Description: exercise all dataset getters on the evaluation split (decode=true);
///   several getters are deliberately called more than once and re-checked
/// Expectation: size 4, columns ["image", "label"], types [uint8, int32], scalar label
///   shape, 2 classes, batch size 1, repeat count 1 — stable across repeated calls
TEST_F(MindDataTestPipeline, TestOmniglotEvaluationDatasetGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOmniglotTestDatasetGetters.";

  // Create a Omniglot Test Dataset.
  std::string folder_path = datasets_root_path_ + "/testOmniglot";
  std::shared_ptr<Dataset> ds = Omniglot(folder_path, false, true);
  EXPECT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 4);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
  std::vector<std::string> column_names = {"image", "label"};
  int64_t num_classes = ds->GetNumClasses();
  EXPECT_EQ(types.size(), 2);
  EXPECT_EQ(types[0].ToString(), "uint8");
  // Label column is asserted to be int32 here.
  EXPECT_EQ(types[1].ToString(), "int32");
  EXPECT_EQ(shapes.size(), 2);
  // "<>" is the string form of a scalar (rank-0) shape.
  EXPECT_EQ(shapes[1].ToString(), "<>");
  EXPECT_EQ(num_classes, 2);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // Repeat the getter calls and check the values have not changed.
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetNumClasses(), 2);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetNumClasses(), 2);
}
/// Feature: TestOmniglotDatasetFail
/// Description: test Omniglot with an empty dataset_dir string
/// Expectation: throw exception correctly (iterator creation fails and returns nullptr)
TEST_F(MindDataTestPipeline, TestOmniglotDatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOmniglotDatasetFail.";

  // Create a Omniglot Dataset with an empty directory path; the node itself is still created.
  std::shared_ptr<Dataset> ds = Omniglot("", true, false, std::make_shared<RandomSampler>(false, 5));
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid Omniglot input.
  EXPECT_EQ(iter, nullptr);
}
/// Feature: TestOmniglotDatasetWithNullSampler
/// Description: test Omniglot with a null sampler pointer
/// Expectation: throw exception correctly (iterator creation fails and returns nullptr)
TEST_F(MindDataTestPipeline, TestOmniglotDatasetWithNullSamplerFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOmniglotDatasetWithNullSamplerFail.";

  // Create a Omniglot Dataset with a null sampler; the node itself is still created.
  std::string folder_path = datasets_root_path_ + "/testOmniglot";
  std::shared_ptr<Dataset> ds = Omniglot(folder_path, true, false, nullptr);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid Omniglot input, sampler cannot be nullptr
  EXPECT_EQ(iter, nullptr);
}
| @@ -0,0 +1,485 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Test Omniglot dataset operators | |||
| """ | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.vision.c_transforms as vision | |||
| from mindspore import log as logger | |||
| DATA_DIR = "../data/dataset/testOmniglot" | |||
def test_omniglot_basic():
    """
    Feature: OmniglotDataset.
    Description: iterate the default (background) split and tally labels and raw image sizes.
    Expectation: 4 rows; labels 0 and 1 appear twice each; first-dimension sizes of the
        (undecoded) image tensors match the expected multiset.
    """
    logger.info("Test Case basic")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    data1 = ds.OmniglotDataset(DATA_DIR)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    count = [0, 0, 0, 0]
    # Map from image first-dimension size (as string) to how often it should occur.
    BASIC_EXPECTED_SHAPE = {"82386": 1, "61235": 1, "159109": 2}
    ACTUAL_SHAPE = {"82386": 0, "61235": 0, "159109": 0}
    # each data is a dictionary.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        ACTUAL_SHAPE[str(item["image"].shape[0])] += 1
        count[item["label"]] += 1
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 4
    assert count == [2, 2, 0, 0]
    assert ACTUAL_SHAPE == BASIC_EXPECTED_SHAPE
def test_omniglot_num_samples():
    """
    Feature: OmniglotDataset.
    Description: read with num_samples=8 and with RandomSampler(num_samples=3) with and
        without replacement.
    Expectation: row count is capped by the dataset size (4) or by the sampler's
        num_samples (3).
    """
    logger.info("Test Case numSamples")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    # num_samples=8 exceeds the 4 available images, so all 4 rows are returned.
    data1 = ds.OmniglotDataset(DATA_DIR, num_samples=8, num_parallel_workers=2)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    # each data is a dictionary.
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 4

    # Sampling with replacement: exactly num_samples rows.
    random_sampler = ds.RandomSampler(num_samples=3, replacement=True)
    data1 = ds.OmniglotDataset(DATA_DIR,
                               num_parallel_workers=2,
                               sampler=random_sampler)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1

    assert num_iter == 3

    # Sampling without replacement: still exactly num_samples rows (3 <= 4).
    random_sampler = ds.RandomSampler(num_samples=3, replacement=False)
    data1 = ds.OmniglotDataset(DATA_DIR,
                               num_parallel_workers=2,
                               sampler=random_sampler)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1

    assert num_iter == 3
def test_omniglot_num_shards():
    """
    Feature: OmniglotDataset.
    Description: shard the 4-row dataset into 4 shards and read shard_id=2.
    Expectation: exactly 1 row, with the expected image size and label.
    """
    logger.info("Test Case numShards")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    data1 = ds.OmniglotDataset(DATA_DIR, num_shards=4, shard_id=2)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    # each data is a dictionary.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        assert item["image"].shape[0] == 82386
        assert item["label"] == 1
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 1
def test_omniglot_shard_id():
    """
    Feature: OmniglotDataset.
    Description: shard the 4-row dataset into 4 shards and read shard_id=1.
    Expectation: exactly 1 row, with the expected image size and label.
    """
    logger.info("Test Case withShardID")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    data1 = ds.OmniglotDataset(DATA_DIR, num_shards=4, shard_id=1)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    # each data is a dictionary.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        assert item["image"].shape[0] == 159109
        assert item["label"] == 0
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 1
def test_omniglot_no_shuffle():
    """
    Feature: OmniglotDataset.
    Description: read with shuffle=False so rows come back in sequential order.
    Expectation: 4 rows whose image sizes appear in the fixed sequential order and
        whose labels 0 and 1 appear twice each.
    """
    logger.info("Test Case noShuffle")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    data1 = ds.OmniglotDataset(DATA_DIR, shuffle=False)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    count = [0, 0, 0, 0]
    # Image first-dimension sizes in the deterministic (unshuffled) read order.
    SHAPE = [159109, 159109, 82386, 61235]
    # each data is a dictionary.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        assert item["image"].shape[0] == SHAPE[num_iter]
        count[item["label"]] += 1
        num_iter += 1

    assert num_iter == 4
    assert count == [2, 2, 0, 0]
def test_omniglot_extra_shuffle():
    """
    Feature: OmniglotDataset.
    Description: combine shuffle=True, an extra .shuffle() op, and repeat(2).
    Expectation: 8 rows in total; labels 0 and 1 appear four times each and each image
        size occurs twice as often as in a single pass.
    """
    logger.info("Test Case extraShuffle")
    # define parameters.
    repeat_count = 2

    # apply dataset operations.
    data1 = ds.OmniglotDataset(DATA_DIR, shuffle=True)
    data1 = data1.shuffle(buffer_size=5)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    count = [0, 0, 0, 0]
    # Expected occurrence counts are doubled by repeat(2).
    EXPECTED_SHAPE = {"82386": 2, "61235": 2, "159109": 4}
    ACTUAL_SHAPE = {"82386": 0, "61235": 0, "159109": 0}
    # each data is a dictionary.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        ACTUAL_SHAPE[str(item["image"].shape[0])] += 1
        count[item["label"]] += 1
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 8
    assert count == [4, 4, 0, 0]
    assert ACTUAL_SHAPE == EXPECTED_SHAPE
def test_omniglot_decode():
    """
    Feature: OmniglotDataset.
    Description: read with decode=True and iterate the full dataset.
    Expectation: all 4 rows are produced without error.
    """
    logger.info("Test Case decode")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    data1 = ds.OmniglotDataset(DATA_DIR, decode=True)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    # each data is a dictionary.
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 4
def test_sequential_sampler():
    """
    Feature: OmniglotDataset.
    Description: read with SequentialSampler(num_samples=8), which is capped at the 4
        available rows and preserves the on-disk order.
    Expectation: 4 rows in the fixed sequential order; labels 0 and 1 twice each.
    """
    logger.info("Test Case SequentialSampler")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    sampler = ds.SequentialSampler(num_samples=8)
    data1 = ds.OmniglotDataset(DATA_DIR, sampler=sampler)
    data_seq = data1.repeat(repeat_count)
    num_iter = 0
    count = [0, 0, 0, 0]
    # Image first-dimension sizes in the deterministic sequential order.
    SHAPE = [159109, 159109, 82386, 61235]
    # each data is a dictionary.
    for item in data_seq.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        assert item["image"].shape[0] == SHAPE[num_iter]
        count[item["label"]] += 1
        num_iter += 1

    assert num_iter == 4
    assert count == [2, 2, 0, 0]
def test_random_sampler():
    """
    Feature: OmniglotDataset.
    Description: read with a default RandomSampler (no num_samples, no replacement).
    Expectation: all 4 rows are produced exactly once each, in some order; labels 0
        and 1 appear twice each.
    """
    logger.info("Test Case RandomSampler")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    sampler = ds.RandomSampler()
    data1 = ds.OmniglotDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    count = [0, 0, 0, 0]
    # Each image size should occur as often as in a full sequential pass.
    RANDOM_EXPECTED_SHAPE = {"82386": 1, "61235": 1, "159109": 2}
    ACTUAL_SHAPE = {"82386": 0, "61235": 0, "159109": 0}
    # each data is a dictionary.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        ACTUAL_SHAPE[str(item["image"].shape[0])] += 1
        count[item["label"]] += 1
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 4
    assert count == [2, 2, 0, 0]
    assert ACTUAL_SHAPE == RANDOM_EXPECTED_SHAPE
def test_distributed_sampler():
    """
    Feature: OmniglotDataset.
    Description: read with DistributedSampler(num_shards=4, shard_id=1) over the
        4-row dataset.
    Expectation: exactly 1 row, with the expected image size and label.
    """
    logger.info("Test Case DistributedSampler")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    sampler = ds.DistributedSampler(4, 1)
    data1 = ds.OmniglotDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    # each data is a dictionary.
    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        # in this example, each dictionary has keys "image" and "label".
        assert item["image"].shape[0] == 159109
        assert item["label"] == 0
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 1
def test_pk_sampler():
    """
    Feature: OmniglotDataset.
    Description: read with PKSampler(1), i.e. one sample per class.
    Expectation: 2 rows (the test data holds two classes — see the [2, 2, 0, 0] label
        counts asserted by the other tests).
    """
    logger.info("Test Case PKSampler")
    # define parameters.
    repeat_count = 1

    # apply dataset operations.
    sampler = ds.PKSampler(1)
    data1 = ds.OmniglotDataset(DATA_DIR, sampler=sampler)
    data1 = data1.repeat(repeat_count)
    num_iter = 0
    # each data is a dictionary.
    for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 2
def test_chained_sampler():
    """
    Feature: load_omniglot.
    Description: load OmniglotDataset with chained random/sequential samplers and repeat.
    Expectation: get data of OmniglotDataset.
    """
    logger.info(
        "Test Case Chained Sampler - Random and Sequential, with repeat")
    # Chain a sequential sampler beneath a random sampler.
    parent_sampler = ds.RandomSampler()
    parent_sampler.add_child(ds.SequentialSampler())
    # Build the dataset with the chained sampler and repeat it three times.
    dataset = ds.OmniglotDataset(DATA_DIR, sampler=parent_sampler)
    dataset = dataset.repeat(count=3)
    # The reported size must already account for the repeat.
    size = dataset.get_dataset_size()
    logger.info("dataset size is: {}".format(size))
    assert size == 12
    # Iterating must yield exactly as many rows as the reported size.
    row_count = 0
    for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        row_count += 1
    logger.info("Number of data in data1: {}".format(row_count))
    assert row_count == 12
def test_omniglot_evaluation():
    """
    Feature: load_omniglot.
    Description: load the evaluation split of OmniglotDataset (background=False).
    Expectation: get data of OmniglotDataset.
    """
    logger.info("Test Case usage")
    # background=False selects the evaluation images; num_samples caps the read.
    dataset = ds.OmniglotDataset(DATA_DIR, background=False, num_samples=6)
    # Count the rows; the cap of 6 exceeds the split size, so 4 are expected.
    row_count = sum(
        1 for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True))
    logger.info("Number of data in data1: {}".format(row_count))
    assert row_count == 4
def test_omniglot_zip():
    """
    Feature: load_omniglot.
    Description: zip two OmniglotDataset instances and iterate the result.
    Expectation: get data of OmniglotDataset.
    """
    logger.info("Test Case zip")
    num_repeats = 2
    # Build two independent readers over the same data.
    left = ds.OmniglotDataset(DATA_DIR, num_samples=8)
    right = ds.OmniglotDataset(DATA_DIR, num_samples=8)
    left = left.repeat(num_repeats)
    # Rename the second dataset's columns so zip has no name conflict.
    right = right.rename(input_columns=["image", "label"],
                         output_columns=["image1", "label1"])
    zipped = ds.zip((left, right))
    # Count the zipped rows; zip stops at the shorter branch.
    row_count = 0
    for _ in zipped.create_dict_iterator(num_epochs=1, output_numpy=True):
        row_count += 1
    logger.info("Number of data in data1: {}".format(row_count))
    assert row_count == 4
def test_omniglot_exception():
    """
    Feature: test_omniglot_exception.
    Description: test error cases for OmniglotDataset.
    Expectation: raise exception.
    """
    logger.info("Test omniglot exception")

    def raise_single(item):
        # Always fails so the mapped PyFunc surfaces a RuntimeError.
        raise Exception("Error occur!")

    def raise_double(image, label):
        # Two-column variant of the failing PyFunc.
        raise Exception("Error occur!")

    def expect_map_failure(dataset):
        # Iterating the pipeline must raise the PyFunc map RuntimeError.
        try:
            for _ in dataset.__iter__():
                pass
            assert False
        except RuntimeError as e:
            assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)

    # Case 1: failing PyFunc over a single input column.
    pipeline = ds.OmniglotDataset(DATA_DIR)
    pipeline = pipeline.map(operations=raise_single,
                            input_columns=["image"],
                            num_parallel_workers=1)
    expect_map_failure(pipeline)

    # Case 2: failing PyFunc that fans two columns out to three.
    pipeline = ds.OmniglotDataset(DATA_DIR)
    pipeline = pipeline.map(operations=raise_double,
                            input_columns=["image", "label"],
                            output_columns=["image", "label", "label1"],
                            column_order=["image", "label", "label1"],
                            num_parallel_workers=1)
    expect_map_failure(pipeline)

    # Case 3: failing PyFunc applied after a successful decode step.
    pipeline = ds.OmniglotDataset(DATA_DIR)
    pipeline = pipeline.map(operations=vision.Decode(),
                            input_columns=["image"],
                            num_parallel_workers=1)
    pipeline = pipeline.map(operations=raise_single,
                            input_columns=["image"],
                            num_parallel_workers=1)
    expect_map_failure(pipeline)
if __name__ == '__main__':
    # Run every test case in sequence when executed as a script.
    all_cases = (
        test_omniglot_basic,
        test_omniglot_num_samples,
        test_sequential_sampler,
        test_random_sampler,
        test_distributed_sampler,
        test_chained_sampler,
        test_pk_sampler,
        test_omniglot_num_shards,
        test_omniglot_shard_id,
        test_omniglot_no_shuffle,
        test_omniglot_extra_shuffle,
        test_omniglot_decode,
        test_omniglot_evaluation,
        test_omniglot_zip,
        test_omniglot_exception,
    )
    for case in all_cases:
        case()