@@ -26,6 +26,7 @@
 #include "minddata/dataset/engine/datasetops/source/coco_op.h"
 #include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
 #include "minddata/dataset/engine/datasetops/source/mnist_op.h"
+#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
 #include "minddata/dataset/engine/datasetops/source/voc_op.h"
 // Dataset operator headers (in alphabetical order)
 #include "minddata/dataset/engine/datasetops/batch_op.h"
@@ -95,6 +96,7 @@ Dataset::Dataset() {
   num_workers_ = cfg->num_parallel_workers();
   rows_per_buffer_ = cfg->rows_per_buffer();
   connector_que_size_ = cfg->op_connector_size();
+  worker_connector_size_ = cfg->worker_connector_size();
 }

 // FUNCTIONS TO CREATE DATASETS FOR LEAF-NODE DATASETS
@@ -140,7 +142,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
 std::shared_ptr<ImageFolderDataset> ImageFolder(std::string dataset_dir, bool decode,
                                                 std::shared_ptr<SamplerObj> sampler, std::set<std::string> extensions,
                                                 std::map<std::string, int32_t> class_indexing) {
-  // This arg is exist in ImageFolderOp, but not externalized (in Python API). The default value is false.
+  // This arg exists in ImageFolderOp, but not externalized (in Python API). The default value is false.
   bool recursive = false;

   // Create logical representation of ImageFolderDataset.
@@ -163,6 +165,16 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
                                          const std::shared_ptr<Dataset> &datasets2) {
   std::shared_ptr<ConcatDataset> ds = std::make_shared<ConcatDataset>(std::vector({datasets1, datasets2}));
+  // Call derived class validation method.
+  return ds->ValidateParams() ? ds : nullptr;
+}
+
+// Function to create a TextFileDataset.
+std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples,
+                                          ShuffleMode shuffle, int32_t num_shards, int32_t shard_id) {
+  auto ds = std::make_shared<TextFileDataset>(dataset_files, num_samples, shuffle, num_shards, shard_id);
+  // Call derived class validation method.
   return ds->ValidateParams() ? ds : nullptr;
 }
@@ -340,6 +352,34 @@ std::shared_ptr<SamplerObj> CreateDefaultSampler() {
   return std::make_shared<RandomSamplerObj>(replacement, num_samples);
 }
+
+// Helper function to compute a default shuffle size
+int64_t ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows) {
+  const int64_t average_files_multiplier = 4;
+  const int64_t shuffle_max = 10000;
+  int64_t avg_rows_per_file = 0;
+  int64_t shuffle_size = 0;
+
+  // Adjust the num rows per shard if sharding was given
+  if (num_devices > 0) {
+    if (num_rows % num_devices == 0) {
+      num_rows = num_rows / num_devices;
+    } else {
+      num_rows = (num_rows / num_devices) + 1;
+    }
+  }
+
+  // Cap based on total rows directive. Some ops do not have this and give value of 0.
+  if (total_rows > 0) {
+    num_rows = std::min(num_rows, total_rows);
+  }
+
+  // Get the average rows per file
+  avg_rows_per_file = num_rows / num_files;
+
+  shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
+  return shuffle_size;
+}
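+
+// Worked example (illustrative numbers, not taken from a real run): for 2 files
+// totalling 5 rows with num_devices = 1, avg_rows_per_file = 5 / 2 = 2, and the
+// result is max(2 * 4, 10000) = 10000. Small datasets therefore hit the
+// 10000-row floor and are effectively fully shuffled; only datasets averaging
+// more than 2500 rows per file exceed it.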
+
 // Helper function to validate dataset params
 bool ValidateCommonDatasetParams(std::string dataset_dir) {
   if (dataset_dir.empty()) {
@@ -613,6 +653,87 @@ std::vector<std::shared_ptr<DatasetOp>> MnistDataset::Build() {
   return node_ops;
 }
+
+// Constructor for TextFileDataset
+TextFileDataset::TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle,
+                                 int32_t num_shards, int32_t shard_id)
+    : dataset_files_(dataset_files),
+      num_samples_(num_samples),
+      shuffle_(shuffle),
+      num_shards_(num_shards),
+      shard_id_(shard_id) {}
+
+bool TextFileDataset::ValidateParams() {
+  if (dataset_files_.empty()) {
+    MS_LOG(ERROR) << "TextFileDataset: dataset_files is not specified.";
+    return false;
+  }
+
+  for (auto file : dataset_files_) {
+    std::ifstream handle(file);
+    if (!handle.is_open()) {
+      MS_LOG(ERROR) << "TextFileDataset: Failed to open file: " << file;
+      return false;
+    }
+  }
+
+  if (num_samples_ < 0) {
+    MS_LOG(ERROR) << "TextFileDataset: Invalid number of samples: " << num_samples_;
+    return false;
+  }
+
+  if (num_shards_ <= 0) {
+    MS_LOG(ERROR) << "TextFileDataset: Invalid num_shards: " << num_shards_;
+    return false;
+  }
+
+  if (shard_id_ < 0 || shard_id_ >= num_shards_) {
+    MS_LOG(ERROR) << "TextFileDataset: Invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
+    return false;
+  }
+
+  return true;
+}
+
+// Function to build TextFileDataset
+std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
+  // A vector containing shared pointer to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);
+
+  // Do internal Schema generation.
+  auto schema = std::make_unique<DataSchema>();
+  RETURN_EMPTY_IF_ERROR(
+    schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
+
+  // Create and initialize TextFileOp
+  std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
+    num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), dataset_files_,
+    connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
+  RETURN_EMPTY_IF_ERROR(text_file_op->Init());
+
+  if (shuffle_ == ShuffleMode::kGlobal) {
+    // Inject ShuffleOp
+    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
+    int64_t shuffle_size = 0;
+    int64_t num_rows = 0;
+
+    // First, get the number of rows in the dataset and then compute the shuffle size
+    RETURN_EMPTY_IF_ERROR(TextFileOp::CountAllFileRows(dataset_files_, &num_rows));
+    shuffle_size = ComputeShuffleSize(dataset_files_.size(), num_shards_, num_rows, 0);
+    MS_LOG(INFO) << "TextFileDataset::Build - num_rows: " << num_rows << ", shuffle_size: " << shuffle_size;
+
+    // Add the shuffle op after this op
+    shuffle_op = std::make_shared<ShuffleOp>(shuffle_size, GetSeed(), connector_que_size_, true, rows_per_buffer_);
+    node_ops.push_back(shuffle_op);
+  }
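+
+  // Note: node_ops appears to be consumed front-to-back when the execution
+  // tree is assembled, so the ShuffleOp pushed above sits above the
+  // TextFileOp pushed below in the resulting tree.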
+  // Add TextFileOp
+  node_ops.push_back(text_file_op);
+  return node_ops;
+}
+
 // Constructor for VOCDataset
 VOCDataset::VOCDataset(const std::string &dataset_dir, const std::string &task, const std::string &mode,
                        const std::map<std::string, int32_t> &class_index, bool decode,
@@ -35,6 +35,9 @@ enum class DatasetType { kUnknown, kArrow, kTf };
 // Possible flavours of Tensor implementations
 enum class TensorImpl { kNone, kFlexible, kCv, kNP };
+
+// Possible values for shuffle
+enum class ShuffleMode { kFalse = 0, kFiles = 1, kGlobal = 2 };
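+// These correspond to the Python API's shuffle options (shuffle=False,
+// Shuffle.FILES, Shuffle.GLOBAL).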
+
 // Possible values for Border types
 enum class BorderType { kConstant = 0, kEdge = 1, kReflect = 2, kSymmetric = 3 };
@@ -23,6 +23,7 @@
 #include <map>
 #include <utility>
 #include <string>
+#include "minddata/dataset/core/constants.h"
 #include "minddata/dataset/include/tensor.h"
 #include "minddata/dataset/include/iterator.h"
 #include "minddata/dataset/include/samplers.h"
@@ -47,6 +48,7 @@ class Cifar100Dataset;
 class CocoDataset;
 class ImageFolderDataset;
 class MnistDataset;
+class TextFileDataset;
 class VOCDataset;
 // Dataset Op classes (in alphabetical order)
 class BatchDataset;
@@ -83,7 +85,7 @@ std::shared_ptr<CelebADataset> CelebA(const std::string &dataset_dir, const std:
 std::shared_ptr<Cifar10Dataset> Cifar10(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler = nullptr);

 /// \brief Function to create a Cifar100 Dataset
-/// \notes The generated dataset has two columns ['image', 'coarse_label', 'fine_label']
+/// \notes The generated dataset has three columns ['image', 'coarse_label', 'fine_label']
 /// \param[in] dataset_dir Path to the root directory that contains the dataset
 /// \param[in] sampler Object used to choose samples from the dataset. If sampler is `nullptr`, A `RandomSampler`
 ///     will be used to randomly iterate the entire dataset
@@ -143,6 +145,25 @@ std::shared_ptr<MnistDataset> Mnist(std::string dataset_dir, std::shared_ptr<Sam
 std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &datasets1,
                                          const std::shared_ptr<Dataset> &datasets2);
+
+/// \brief Function to create a TextFileDataset
+/// \notes The generated dataset has one column ['text']
+/// \param[in] dataset_files List of files to be read; the list will be sorted in lexicographical order
+/// \param[in] num_samples The number of samples to be included in the dataset.
+///     (Default = 0 means all samples.)
+/// \param[in] shuffle The mode for shuffling data every epoch. (Default = ShuffleMode::kGlobal)
+///     Can be any of:
+///     ShuffleMode::kFalse - No shuffling is performed.
+///     ShuffleMode::kFiles - Shuffle files only.
+///     ShuffleMode::kGlobal - Shuffle both the files and samples.
+/// \param[in] num_shards Number of shards that the dataset should be divided into. (Default = 1)
+/// \param[in] shard_id The shard ID within num_shards. This argument should be
+///     specified only when num_shards is also specified. (Default = 0)
+/// \return Shared pointer to the current TextFileDataset
+std::shared_ptr<TextFileDataset> TextFile(std::vector<std::string> dataset_files, int32_t num_samples = 0,
+                                          ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
+                                          int32_t shard_id = 0);
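+
+// A minimal usage sketch (the file path below is illustrative, not from the
+// repository): read every row of one file without shuffling, split into two
+// shards, and take shard 0:
+//   std::shared_ptr<Dataset> ds = TextFile({"/path/corpus.txt"}, 0, ShuffleMode::kFalse, 2, 0);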
+
 /// \brief Function to create a VOCDataset
 /// \notes The generated dataset has multi-columns :
 ///     - task='Detection', column: [['image', dtype=uint8], ['bbox', dtype=float32], ['label', dtype=uint32],
@@ -289,10 +310,14 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   int32_t num_workers_;
   int32_t rows_per_buffer_;
   int32_t connector_que_size_;
+  int32_t worker_connector_size_;
 };

 /* ####################################### Derived Dataset classes ################################# */

+// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
+// (In alphabetical order)
+
 class CelebADataset : public Dataset {
  public:
   /// \brief Constructor
@@ -318,6 +343,8 @@ class CelebADataset : public Dataset {
   std::set<std::string> extensions_;
   std::shared_ptr<SamplerObj> sampler_;
 };
+
+// DERIVED DATASET CLASSES FOR LEAF-NODE DATASETS
+// (In alphabetical order)

 class Cifar10Dataset : public Dataset {
  public:
@@ -435,6 +462,33 @@ class MnistDataset : public Dataset {
   std::shared_ptr<SamplerObj> sampler_;
 };

+/// \class TextFileDataset
+/// \brief A Dataset derived class to represent TextFile dataset
+class TextFileDataset : public Dataset {
+ public:
+  /// \brief Constructor
+  TextFileDataset(std::vector<std::string> dataset_files, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards,
+                  int32_t shard_id);
+
+  /// \brief Destructor
+  ~TextFileDataset() = default;
+
+  /// \brief a base class override function to create the required runtime dataset op objects for this class
+  /// \return The list of shared pointers to the newly created DatasetOps
+  std::vector<std::shared_ptr<DatasetOp>> Build() override;
+
+  /// \brief Parameters validation
+  /// \return bool true if all the params are valid
+  bool ValidateParams() override;
+
+ private:
+  std::vector<std::string> dataset_files_;
+  int32_t num_samples_;
+  int32_t num_shards_;
+  int32_t shard_id_;
+  ShuffleMode shuffle_;
+};
+
 class VOCDataset : public Dataset {
  public:
   /// \brief Constructor
@@ -467,6 +521,9 @@ class VOCDataset : public Dataset {
   std::shared_ptr<SamplerObj> sampler_;
 };
+
+// DERIVED DATASET CLASSES FOR DATASET OPS
+// (In alphabetical order)

 class BatchDataset : public Dataset {
  public:
   /// \brief Constructor
@@ -5012,7 +5012,7 @@ class CSVDataset(SourceDataset):
 class TextFileDataset(SourceDataset):
     """
     A source dataset that reads and parses datasets stored on disk in text format.
-    The generated dataset has one columns ['text'].
+    The generated dataset has one column ['text'].

     Args:
         dataset_files (Union[str, list[str]]): String or list of files to be read or glob strings to search for a
@@ -97,6 +97,7 @@ SET(DE_UT_SRCS
     c_api_dataset_ops_test.cc
     c_api_dataset_cifar_test.cc
     c_api_dataset_coco_test.cc
+    c_api_dataset_filetext_test.cc
     c_api_dataset_voc_test.cc
     c_api_datasets_test.cc
    c_api_dataset_iterator_test.cc
@@ -0,0 +1,596 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <vector>
+
+#include "utils/log_adapter.h"
+#include "utils/ms_utils.h"
+#include "common/common.h"
+#include "gtest/gtest.h"
+#include "./securec.h"
+#include "minddata/dataset/core/client.h"
+#include "minddata/dataset/core/config_manager.h"
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/core/global_context.h"
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/core/tensor_shape.h"
+#include "minddata/dataset/include/datasets.h"
+#include "minddata/dataset/include/iterator.h"
+#include "minddata/dataset/include/samplers.h"
+#include "minddata/dataset/include/status.h"
+#include "minddata/dataset/include/transforms.h"
+
+using namespace mindspore::dataset;
+using namespace mindspore::dataset::api;
+using mindspore::LogStream;
+using mindspore::dataset::DataType;
+using mindspore::dataset::ShuffleMode;
+using mindspore::dataset::Status;
+using mindspore::dataset::Tensor;
+using mindspore::dataset::TensorImpl;
+using mindspore::dataset::TensorShape;
+using mindspore::ExceptionType::NoExceptionType;
+using mindspore::MsLogLevel::ERROR;
+
+class MindDataTestPipeline : public UT::DatasetOpTesting {
+ protected:
+};
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetBasic) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetBasic.";
+  // Test TextFile Dataset with single text file and many default inputs
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(987);
+  GlobalContext::config_manager()->set_num_parallel_workers(4);
+
+  // Create a TextFile Dataset, with single text file
+  // Note: 1.txt has 3 rows
+  // Use 2 samples
+  // Use defaults for other input parameters
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 2);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"Be happy every day.", "This is a text file."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 samples
+  EXPECT_EQ(i, 2);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse1) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse1.";
+  // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=1
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(654);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+
+  // Create a TextFile Dataset, with two text files
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Use default of all samples
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFalse);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"This is a text file.", "Be happy every day.", "Good luck to everyone.",
+                                              "Another file.", "End of file."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 + 3 = 5 samples
+  EXPECT_EQ(i, 5);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFalse4Shard) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFalse4Shard.";
+  // Test TextFile Dataset with two text files and no shuffle, num_parallel_workers=4, shard coverage
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(654);
+  GlobalContext::config_manager()->set_num_parallel_workers(4);
+
+  // Create a TextFile Dataset, with two text files
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Set shuffle to ShuffleMode::kFalse, num_shards=2, shard_id=0
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFalse, 2, 0);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"This is a text file.", "Be happy every day.", "Good luck to everyone."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 3 samples for this shard
+  EXPECT_EQ(i, 3);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleGlobal1A) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleGlobal1A.";
+  // Test TextFile Dataset with 1 text file, global shuffle, num_parallel_workers=1
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(246);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+
+  // Create a TextFile Dataset, with one text file
+  // Note: 1.txt has 3 rows
+  // Set shuffle to global shuffle
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 0, ShuffleMode::kGlobal);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"Good luck to everyone.", "This is a text file.", "Be happy every day."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 3 samples
+  EXPECT_EQ(i, 3);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleGlobal1B) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleGlobal1B.";
+  // Test TextFile Dataset with 2 text files, global shuffle, num_parallel_workers=1
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(246);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+
+  // Create a TextFile Dataset, with two text files
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Set shuffle to global shuffle
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kGlobal);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"Another file.", "Good luck to everyone.", "This is a text file.",
+                                              "End of file.", "Be happy every day."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 + 3 = 5 samples
+  EXPECT_EQ(i, 5);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleGlobal4) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleGlobal4.";
+  // Test TextFile Dataset with 2 text files, global shuffle, num_parallel_workers=4
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(246);
+  GlobalContext::config_manager()->set_num_parallel_workers(4);
+
+  // Create a TextFile Dataset, with two text files
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Set shuffle to global shuffle
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kGlobal);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"Another file.", "Good luck to everyone.", "End of file.",
+                                              "This is a text file.", "Be happy every day."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 + 3 = 5 samples
+  EXPECT_EQ(i, 5);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles1) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles1.";
+  // Test TextFile Dataset with files shuffle, num_parallel_workers=1
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(135);
+  GlobalContext::config_manager()->set_num_parallel_workers(1);
+
+  // Create a TextFile Dataset, with two text files
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Use default of all samples
+  // Set shuffle to files shuffle
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {
+    "This is a text file.", "Be happy every day.", "Good luck to everyone.", "Another file.", "End of file.",
+  };
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 + 3 = 5 samples
+  EXPECT_EQ(i, 5);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetShuffleFiles4) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetShuffleFiles4.";
+  // Test TextFile Dataset with files shuffle, num_parallel_workers=4
+
+  // Set configuration
+  uint32_t original_seed = GlobalContext::config_manager()->seed();
+  uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers();
+  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
+  GlobalContext::config_manager()->set_seed(135);
+  GlobalContext::config_manager()->set_num_parallel_workers(4);
+
+  // Create a TextFile Dataset, with two text files
+  // Note: 1.txt has 3 rows
+  // Note: 2.txt has 2 rows
+  // Use default of all samples
+  // Set shuffle to files shuffle
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1, tf_file2}, 0, ShuffleMode::kFiles);
+  EXPECT_NE(ds, nullptr);
+
+  // Create an iterator over the result of the above dataset.
+  // This will trigger the creation of the Execution Tree and launch it.
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(iter, nullptr);
+
+  // Iterate the dataset and get each row
+  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
+  iter->GetNextRow(&row);
+  EXPECT_NE(row.find("text"), row.end());
+  std::vector<std::string> expected_result = {"This is a text file.", "Another file.", "Be happy every day.",
+                                              "End of file.", "Good luck to everyone."};
+
+  uint64_t i = 0;
+  while (row.size() != 0) {
+    auto text = row["text"];
+    MS_LOG(INFO) << "Tensor text shape: " << text->shape();
+    std::string_view sv;
+    text->GetItemAt(&sv, {0});
+    std::string ss(sv);
+    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
+    // Compare against expected result
+    EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
+    i++;
+    iter->GetNextRow(&row);
+  }
+
+  // Expect 2 + 3 = 5 samples
+  EXPECT_EQ(i, 5);
+
+  // Manually terminate the pipeline
+  iter->Stop();
+
+  // Restore configuration
+  GlobalContext::config_manager()->set_seed(original_seed);
+  GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetFail1) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail1.";
+
+  // Attempt to create a TextFile Dataset
+  // with invalid num_samples=-1
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1}, -1);
+
+  // Expect failure: Number of samples cannot be negative
+  EXPECT_EQ(ds, nullptr);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetFail2) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail2.";
+
+  // Attempt to create a TextFile Dataset
+  // with wrongful empty dataset_files input
+  std::shared_ptr<Dataset> ds = TextFile({});
+
+  // Expect failure: dataset_files is not specified
+  EXPECT_EQ(ds, nullptr);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetFail3) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail3.";
+
+  // Attempt to create a TextFile Dataset
+  // with non-existent dataset_files input
+  std::shared_ptr<Dataset> ds = TextFile({"notexist.txt"}, 0, ShuffleMode::kFalse);
+
+  // Expect failure: specified dataset_files does not exist
+  EXPECT_EQ(ds, nullptr);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetFail4) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail4.";
+
+  // Attempt to create a TextFile Dataset
+  // with empty string dataset_files input
+  std::shared_ptr<Dataset> ds = TextFile({""}, 0, ShuffleMode::kFiles);
+
+  // Expect failure: specified dataset_files does not exist
+  EXPECT_EQ(ds, nullptr);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetFail5) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail5.";
+
+  // Attempt to create a TextFile Dataset
+  // with invalid num_shards=0 value
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 1, ShuffleMode::kFalse, 0);
+
+  // Expect failure: Number of shards cannot be <=0
+  EXPECT_EQ(ds, nullptr);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetFail6) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail6.";
+
+  // Attempt to create a TextFile Dataset
+  // with invalid shard_id=-1 value (num_shards=1 so that shard_id is the invalid argument)
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 0, ShuffleMode::kFiles, 1, -1);
+
+  // Expect failure: shard_id cannot be negative
+  EXPECT_EQ(ds, nullptr);
+}
+
+TEST_F(MindDataTestPipeline, TestTextFileDatasetFail7) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTextFileDatasetFail7.";
+
+  // Attempt to create a TextFile Dataset
+  // with invalid shard_id=2 and num_shards=2 combination
+  std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
+  std::shared_ptr<Dataset> ds = TextFile({tf_file1}, 0, ShuffleMode::kGlobal, 2, 2);
+
+  // Expect failure: Cannot have shard_id >= num_shards
+  EXPECT_EQ(ds, nullptr);
+}
@@ -89,6 +89,23 @@ TEST_F(MindDataTestTextFileOp, TestTextFileBasic) {
   ASSERT_EQ(row_count, 3);
 }
+
+TEST_F(MindDataTestTextFileOp, TestTextFileFileNotExist) {
+  // Start with an empty execution tree
+  auto tree = std::make_shared<ExecutionTree>();
+
+  std::string dataset_path = datasets_root_path_ + "/does/not/exist/0.txt";
+
+  std::shared_ptr<TextFileOp> op;
+  TextFileOp::Builder builder;
+  builder.SetTextFilesList({dataset_path})
+      .SetRowsPerBuffer(16)
+      .SetNumWorkers(16)
+      .SetOpConnectorSize(2);
+
+  Status rc = builder.Build(&op);
+  ASSERT_TRUE(rc.IsOk());
+}
+
 TEST_F(MindDataTestTextFileOp, TestTotalRows) {
   std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt";
   std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt";
@@ -110,3 +127,14 @@ TEST_F(MindDataTestTextFileOp, TestTotalRows) {
   ASSERT_EQ(total_rows, 5);
   files.clear();
 }
+
+TEST_F(MindDataTestTextFileOp, TestTotalRowsFileNotExist) {
+  std::string tf_file1 = datasets_root_path_ + "/does/not/exist/0.txt";
+  std::vector<std::string> files;
+  files.push_back(tf_file1);
+
+  int64_t total_rows = 0;
+  TextFileOp::CountAllFileRows(files, &total_rows);
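+  // CountAllFileRows returns a Status, which this test leaves unchecked; the
+  // assertion below only verifies that the count stays 0 for a missing file.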
+  ASSERT_EQ(total_rows, 0);
+}
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+import pytest
 import mindspore.dataset as ds
 from mindspore import log as logger
-from util import config_get_set_num_parallel_workers
+from util import config_get_set_num_parallel_workers, config_get_set_seed

 DATA_FILE = "../data/dataset/testTextFileDataset/1.txt"
@@ -39,10 +40,54 @@ def test_textline_dataset_all_file():
     assert count == 5


-def test_textline_dataset_totext():
+def test_textline_dataset_num_samples_zero():
+    data = ds.TextFileDataset(DATA_FILE, num_samples=0)
+    count = 0
+    for i in data.create_dict_iterator():
+        logger.info("{}".format(i["text"]))
+        count += 1
+    assert count == 3
+
+
+def test_textline_dataset_shuffle_false4():
     original_num_parallel_workers = config_get_set_num_parallel_workers(4)
+    original_seed = config_get_set_seed(987)
+    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
+    count = 0
+    line = ["This is a text file.", "Another file.",
+            "Be happy every day.", "End of file.", "Good luck to everyone."]
+    for i in data.create_dict_iterator():
+        strs = i["text"].item().decode("utf8")
+        assert strs == line[count]
+        count += 1
+    assert count == 5
+    # Restore configuration
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+    ds.config.set_seed(original_seed)
+
+
+def test_textline_dataset_shuffle_false1():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+    original_seed = config_get_set_seed(987)
     data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
     count = 0
+    line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
+            "Another file.", "End of file."]
+    for i in data.create_dict_iterator():
+        strs = i["text"].item().decode("utf8")
+        assert strs == line[count]
+        count += 1
+    assert count == 5
+    # Restore configuration
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+    ds.config.set_seed(original_seed)
+
+
+def test_textline_dataset_shuffle_files4():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
+    original_seed = config_get_set_seed(135)
+    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
+    count = 0
     line = ["This is a text file.", "Another file.",
             "Be happy every day.", "End of file.", "Good luck to everyone."]
     for i in data.create_dict_iterator():
@@ -50,8 +95,60 @@ def test_textline_dataset_totext():
         assert strs == line[count]
         count += 1
     assert count == 5
-    # Restore configuration num_parallel_workers
+    # Restore configuration
     ds.config.set_num_parallel_workers(original_num_parallel_workers)
+    ds.config.set_seed(original_seed)
+
+
+def test_textline_dataset_shuffle_files1():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+    original_seed = config_get_set_seed(135)
+    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.FILES)
+    count = 0
+    line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.",
+            "Another file.", "End of file."]
+    for i in data.create_dict_iterator():
+        strs = i["text"].item().decode("utf8")
+        assert strs == line[count]
+        count += 1
+    assert count == 5
+    # Restore configuration
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+    ds.config.set_seed(original_seed)
+
+
+def test_textline_dataset_shuffle_global4():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(4)
+    original_seed = config_get_set_seed(246)
+    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
+    count = 0
+    line = ["Another file.", "Good luck to everyone.", "End of file.",
+            "This is a text file.", "Be happy every day."]
+    for i in data.create_dict_iterator():
+        strs = i["text"].item().decode("utf8")
+        assert strs == line[count]
+        count += 1
+    assert count == 5
+    # Restore configuration
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+    ds.config.set_seed(original_seed)
+
+
+def test_textline_dataset_shuffle_global1():
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+    original_seed = config_get_set_seed(246)
+    data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=ds.Shuffle.GLOBAL)
+    count = 0
+    line = ["Another file.", "Good luck to everyone.", "This is a text file.",
+            "End of file.", "Be happy every day."]
+    for i in data.create_dict_iterator():
+        strs = i["text"].item().decode("utf8")
+        assert strs == line[count]
+        count += 1
+    assert count == 5
+    # Restore configuration
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
+    ds.config.set_seed(original_seed)
+
+
 def test_textline_dataset_num_samples():
@@ -94,11 +191,33 @@ def test_textline_dataset_to_device():
     data = data.to_device()
     data.send()

+
+def test_textline_dataset_exceptions():
+    with pytest.raises(ValueError) as error_info:
+        _ = ds.TextFileDataset(DATA_FILE, num_samples=-1)
+    assert "Input num_samples is not within the required interval" in str(error_info.value)
+
+    with pytest.raises(ValueError) as error_info:
+        _ = ds.TextFileDataset("does/not/exist/no.txt")
+    assert "The following patterns did not match any files" in str(error_info.value)
+
+    with pytest.raises(ValueError) as error_info:
+        _ = ds.TextFileDataset("")
+    assert "The following patterns did not match any files" in str(error_info.value)
+
+
 if __name__ == "__main__":
     test_textline_dataset_one_file()
     test_textline_dataset_all_file()
-    test_textline_dataset_totext()
+    test_textline_dataset_num_samples_zero()
+    test_textline_dataset_shuffle_false4()
+    test_textline_dataset_shuffle_false1()
+    test_textline_dataset_shuffle_files4()
+    test_textline_dataset_shuffle_files1()
+    test_textline_dataset_shuffle_global4()
+    test_textline_dataset_shuffle_global1()
     test_textline_dataset_num_samples()
     test_textline_dataset_distribution()
     test_textline_dataset_repeat()
     test_textline_dataset_get_datasetsize()
     test_textline_dataset_to_device()
+    test_textline_dataset_exceptions()