| @@ -102,6 +102,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h" | |||
| @@ -1151,6 +1152,12 @@ EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::ve | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| EnWik9Dataset::EnWik9Dataset(const std::vector<char> &dataset_dir, int64_t num_samples, ShuffleMode shuffle, | |||
| int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache) { | |||
| auto ds = std::make_shared<EnWik9Node>(CharToString(dataset_dir), num_samples, shuffle, num_shards, shard_id, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes, | |||
| int32_t base_seed, const std::shared_ptr<Sampler> &sampler, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| @@ -39,6 +39,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/dbpedia_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h" | |||
| @@ -248,6 +249,18 @@ PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) { | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(EnWik9Node, 2, ([](const py::module *m) { | |||
| (void)py::class_<EnWik9Node, DatasetNode, std::shared_ptr<EnWik9Node>>(*m, "EnWik9Node", | |||
| "to create an EnWik9Node") | |||
| .def(py::init([](std::string dataset_dir, int32_t num_samples, int32_t shuffle, int32_t num_shards, | |||
| int32_t shard_id) { | |||
| std::shared_ptr<EnWik9Node> en_wik9 = std::make_shared<EnWik9Node>( | |||
| dataset_dir, num_samples, toShuffleMode(shuffle), num_shards, shard_id, nullptr); | |||
| THROW_IF_ERROR(en_wik9->ValidateParams()); | |||
| return en_wik9; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<FakeImageNode, DatasetNode, std::shared_ptr<FakeImageNode>>( | |||
| *m, "FakeImageNode", "to create a FakeImageNode") | |||
| @@ -17,6 +17,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| dbpedia_op.cc | |||
| div2k_op.cc | |||
| emnist_op.cc | |||
| en_wik9_op.cc | |||
| fake_image_op.cc | |||
| fashion_mnist_op.cc | |||
| flickr_op.cc | |||
| @@ -0,0 +1,118 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/datasetops/source/en_wik9_op.h" | |||
| #include <fstream> | |||
| #include <utility> | |||
| #include "utils/file_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| EnWik9Op::EnWik9Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema> data_schema, const std::vector<std::string> &file_list, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id) | |||
| : TextFileOp(num_workers, total_rows, worker_connector_size, std::move(data_schema), file_list, op_connector_size, | |||
| shuffle_files, num_devices, device_id) {} | |||
| // A print method typically used for debugging. | |||
| void EnWik9Op::Print(std::ostream &out, bool show_all) const { | |||
| if (!show_all) { | |||
| // Call the super class for displaying any common 1-liner info. | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal 1-liner info for this op. | |||
| out << "\n"; | |||
| } else { | |||
| // Call the super class for displaying any common detailed info. | |||
| ParallelOp::Print(out, show_all); | |||
| // Then show any custom derived-internal stuff. | |||
| out << "\nRow count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ | |||
| << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nEnWik9 file path:\n"; | |||
| for (size_t i = 0; i < text_files_list_.size(); ++i) { | |||
| // Print the name of per file path. | |||
| out << " " << text_files_list_[i]; | |||
| } | |||
| out << "\nData Schema:\n"; | |||
| out << *data_schema_ << "\n\n"; | |||
| } | |||
| } | |||
| Status EnWik9Op::LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) { | |||
| auto realpath = FileUtils::GetRealPath(file.data()); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Invalid file path, " << file << " does not exist."; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file path, " + file + " does not exist."); | |||
| } | |||
| std::ifstream handle(realpath.value()); | |||
| if (!handle.is_open()) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + file + | |||
| ". Check if the file is damaged or permission denied."); | |||
| } | |||
| int64_t rows_total = 0; | |||
| std::string line; | |||
| while (getline(handle, line)) { | |||
| if (line.empty()) { | |||
| line = ""; | |||
| } | |||
| // If read to the end offset of this file, break. | |||
| if (rows_total >= end_offset) { | |||
| break; | |||
| } | |||
| // Skip line before start offset. | |||
| if (rows_total < start_offset) { | |||
| rows_total++; | |||
| continue; | |||
| } | |||
| TensorRow tRow(1, nullptr); | |||
| tRow.setPath({file}); | |||
| RETURN_IF_NOT_OK(LoadTensor(line, &tRow)); | |||
| RETURN_IF_NOT_OK(jagged_rows_connector_->Add(worker_id, std::move(tRow))); | |||
| rows_total++; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| int64_t EnWik9Op::CountTotalRows(const std::string &file) { | |||
| auto realpath = FileUtils::GetRealPath(file.data()); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Invalid file, " << file << " does not exist."; | |||
| return 0; | |||
| } | |||
| std::ifstream handle(realpath.value()); | |||
| if (!handle.is_open()) { | |||
| MS_LOG(ERROR) << "Invalid file, failed to open file: " << file | |||
| << ". Check if the file is damaged or permission denied."; | |||
| return 0; | |||
| } | |||
| std::string line; | |||
| int64_t count = 0; | |||
| while (getline(handle, line)) { | |||
| count++; | |||
| } | |||
| return count; | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,77 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EN_WIK9_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EN_WIK9_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/text_file_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class EnWik9Op : public TextFileOp { | |||
| public: | |||
| /// \brief Constructor. | |||
| /// \param[in] num_workers The number of worker threads reading data from enwiki files. | |||
| /// \param[in] total_rows The number of rows to read. | |||
| /// \param[in] worker_connector_size Size of each internal queue. | |||
| /// \param[in] data_schema The data schema object. | |||
| /// \param[in] files_list List of file paths for the dataset files. | |||
| /// \param[in] op_connector_size Size of each queue in the connector that the child operator pulls from. | |||
| /// \param[in] shuffle_files Whether or not to shuffle the files before reading data. | |||
| /// \param[in] num_devices The number of devices. | |||
| /// \param[in] device_id Id of device. | |||
| EnWik9Op(int32_t num_workers, int64_t total_rows, int32_t worker_connector_size, | |||
| std::unique_ptr<DataSchema> data_schema, const std::vector<std::string> &file_list, | |||
| int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id); | |||
| /// \brief Default destructor. | |||
| ~EnWik9Op() = default; | |||
| /// \brief A print method typically used for debugging. | |||
| /// \param[out] out The output stream to write output to. | |||
| /// \param[in] show_all A bool to control if you want to show all info or just a summary. | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| /// \brief Op name getter. | |||
| /// \return Name of the current Op. | |||
| std::string Name() const override { return "EnWik9Op"; } | |||
| /// \brief DatasetName name getter. | |||
| /// \param[in] upper A bool to control if you need upper DatasetName. | |||
| /// \return DatasetName of the current Op. | |||
| virtual std::string DatasetName(bool upper = false) const { return upper ? "EnWik9" : "enwik9"; } | |||
| /// \brief Reads a text file and loads the data into multiple TensorRows. | |||
| /// \param[in] file The file to read. | |||
| /// \param[in] start_offset - the start offset of file. | |||
| /// \param[in] end_offset - the end offset of file. | |||
| /// \param[in] The id of the worker that is executing this function. | |||
| /// \return Status The error code returned. | |||
| Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override; | |||
| private: | |||
| /// \brief Count number of rows in each file. | |||
| /// \param[in] file Txt file name. | |||
| /// \return int64_t The total number of rows in file. | |||
| int64_t CountTotalRows(const std::string &file); | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_EN_WIK9_OP_H_ | |||
| @@ -248,6 +248,5 @@ Status TextFileOp::ComputeColMap() { | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -114,7 +114,7 @@ class TextFileOp : public NonMappableLeafOp { | |||
| // Count number of rows in each file. | |||
| // @param file - txt file name. | |||
| // @return int64_t - the total number of rows in file. | |||
| int64_t CountTotalRows(const std::string &file); | |||
| virtual int64_t CountTotalRows(const std::string &file); | |||
| std::vector<std::string> text_files_list_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| @@ -91,6 +91,7 @@ constexpr char kCSVNode[] = "CSVDataset"; | |||
| constexpr char kDBpediaNode[] = "DBpediaDataset"; | |||
| constexpr char kDIV2KNode[] = "DIV2KDataset"; | |||
| constexpr char kEMnistNode[] = "EMnistDataset"; | |||
| constexpr char kEnWik9Node[] = "EnWik9Dataset"; | |||
| constexpr char kFakeImageNode[] = "FakeImageDataset"; | |||
| constexpr char kFashionMnistNode[] = "FashionMnistDataset"; | |||
| constexpr char kFlickrNode[] = "FlickrDataset"; | |||
| @@ -18,6 +18,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | |||
| dbpedia_node.cc | |||
| div2k_node.cc | |||
| emnist_node.cc | |||
| en_wik9_node.cc | |||
| fake_image_node.cc | |||
| fashion_mnist_node.cc | |||
| flickr_node.cc | |||
| @@ -0,0 +1,174 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/en_wik9_node.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/en_wik9_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for EnWik9Node | |||
| EnWik9Node::EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | |||
| int32_t shard_id, std::shared_ptr<DatasetCache> cache) | |||
| : NonMappableSourceNode(std::move(cache)), | |||
| num_samples_(num_samples), | |||
| shuffle_(shuffle), | |||
| num_shards_(num_shards), | |||
| shard_id_(shard_id), | |||
| dataset_dir_(dataset_dir) { | |||
| // Update the num_shards_ in global context. this number is only used for now by auto_num_worker_pass. User discretion | |||
| // is advised. Auto_num_worker_pass is currently an experimental feature which can still work if the num_shards_ isn't | |||
| // 100% correct. The reason behind is for now, PreBuildSampler doesn't offer a way to return num_shards. Once | |||
| // PreBuildSampler is phased out, this can be cleaned up. | |||
| GlobalContext::config_manager()->set_num_shards_for_auto_num_workers(num_shards_); | |||
| DirToPath(dataset_dir_); | |||
| } | |||
| std::shared_ptr<DatasetNode> EnWik9Node::Copy() { | |||
| auto node = std::make_shared<EnWik9Node>(dataset_dir_, num_samples_, shuffle_, num_shards_, shard_id_, cache_); | |||
| return node; | |||
| } | |||
| void EnWik9Node::Print(std::ostream &out) const { | |||
| out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + | |||
| ", num_shards: " + std::to_string(num_shards_) + ", shard_id: " + std::to_string(shard_id_) + ")"); | |||
| } | |||
| Status EnWik9Node::ValidateParams() { | |||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("EnWik9Dataset", dataset_dir_)); | |||
| RETURN_IF_NOT_OK(ValidateEnum("EnWik9Dataset", "ShuffleMode", shuffle_, | |||
| {ShuffleMode::kFalse, ShuffleMode::kFiles, ShuffleMode::kGlobal})); | |||
| RETURN_IF_NOT_OK(ValidateScalar("EnWik9Dataset", "num_samples", num_samples_, {0}, false)); | |||
| RETURN_IF_NOT_OK(ValidateDatasetShardParams("EnWik9Dataset", num_shards_, shard_id_)); | |||
| return Status::OK(); | |||
| } | |||
| void EnWik9Node::DirToPath(const std::string &dataset_dir) { | |||
| Path train_prefix("enwik9"); | |||
| Path dir(dataset_dir); | |||
| Path temp_path = dir / train_prefix; | |||
| src_target_file_list_.push_back(temp_path.ToString()); | |||
| } | |||
| // Function to build EnWik9Node | |||
| Status EnWik9Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| // Do internal Schema generation. | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); | |||
| // Create and initialize EnWik9Op | |||
| std::shared_ptr<EnWik9Op> en_wik9_op = | |||
| std::make_shared<EnWik9Op>(num_workers_, num_samples_, worker_connector_size_, std::move(schema), | |||
| src_target_file_list_, connector_que_size_, shuffle_files, num_shards_, shard_id_); | |||
| RETURN_IF_NOT_OK(en_wik9_op->Init()); | |||
| // If a global shuffle is used for EnWik9, it will inject a shuffle op over the EnWik9. | |||
| // But, if there is a cache in the tree, we do not need the global shuffle and the shuffle op should not be built. | |||
| // This is achieved in the cache transform pass where we call MakeSimpleProducer to reset EnWik9's shuffle | |||
| // option to false. | |||
| if (shuffle_ == ShuffleMode::kGlobal) { | |||
| // Inject ShuffleOp | |||
| std::shared_ptr<DatasetOp> shuffle_op = nullptr; | |||
| int64_t num_rows = 0; | |||
| // First, get the number of rows in the dataset | |||
| RETURN_IF_NOT_OK(EnWik9Op::CountAllFileRows(src_target_file_list_, &num_rows)); | |||
| // Add the shuffle op after this op | |||
| RETURN_IF_NOT_OK( | |||
| AddShuffleOp(src_target_file_list_.size(), num_shards_, num_rows, 0, connector_que_size_, &shuffle_op)); | |||
| shuffle_op->SetTotalRepeats(GetTotalRepeats()); | |||
| shuffle_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(shuffle_op); | |||
| } | |||
| en_wik9_op->SetTotalRepeats(GetTotalRepeats()); | |||
| en_wik9_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| // Add EnWik9Op | |||
| node_ops->push_back(en_wik9_op); | |||
| return Status::OK(); | |||
| } | |||
| // Get the shard id of node | |||
| Status EnWik9Node::GetShardId(int32_t *shard_id) { | |||
| *shard_id = shard_id_; | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size | |||
| Status EnWik9Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size = num_samples_; | |||
| RETURN_IF_NOT_OK(EnWik9Op::CountAllFileRows(src_target_file_list_, &num_rows)); | |||
| num_rows = static_cast<int64_t>(ceil(num_rows / (1.0 * num_shards_))); | |||
| *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| Status EnWik9Node::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["num_samples"] = num_samples_; | |||
| args["shuffle"] = shuffle_; | |||
| args["num_shards"] = num_shards_; | |||
| args["shard_id"] = shard_id_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| // Note: The following two functions are common among NonMappableSourceNode and should be promoted to its parent class. | |||
| // EnWik9 by itself is a non-mappable dataset that does not support sampling. | |||
| // However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| // inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| // That is why we setup the sampler for a leaf node that does not use sampling. | |||
| Status EnWik9Node::SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) { | |||
| bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles); | |||
| *sampler = SelectSampler(num_samples_, shuffle_files, num_shards_, shard_id_); | |||
| return Status::OK(); | |||
| } | |||
| // If a cache has been added into the ascendant tree over this EnWik9 node, then the cache will be executing | |||
| // a sampler for fetching the data. As such, any options in the EnWik9 node need to be reset to its defaults so | |||
| // that this EnWik9 node will produce the full set of data into the cache. | |||
| Status EnWik9Node::MakeSimpleProducer() { | |||
| shard_id_ = 0; | |||
| num_shards_ = 1; | |||
| shuffle_ = ShuffleMode::kFalse; | |||
| num_samples_ = 0; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,136 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EN_WIK9_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EN_WIK9_NODE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| /// \class EnWik9Node. | |||
| /// \brief A Dataset derived class to represent EnWik9 dataset. | |||
| class EnWik9Node : public NonMappableSourceNode { | |||
| public: | |||
| /// \brief Constructor. | |||
| /// \param[in] dataset_dir The directory of dataset. | |||
| /// \param[in] num_samples The number of samples that users want to get. | |||
| /// \param[in] shuffle Decide the dataset shuffle pattern. | |||
| /// \param[in] num_shards The number of shards that users want to part. | |||
| /// \param[in] shard_id The id of shard. | |||
| /// \param[in] cache Tensor cache to use. | |||
| EnWik9Node(const std::string &dataset_dir, int32_t num_samples, ShuffleMode shuffle, int32_t num_shards, | |||
| int32_t shard_id, std::shared_ptr<DatasetCache> cache); | |||
| /// \brief Destructor. | |||
| ~EnWik9Node() = default; | |||
| /// \brief Node name getter. | |||
| /// \return Name of the current node. | |||
| std::string Name() const override { return kEnWik9Node; } | |||
| /// \brief Print the description. | |||
| /// \param[out] out The output stream to write output to. | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object. | |||
| /// \return A shared pointer to the new copy. | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class. | |||
| /// \param[in] node_ops A vector containing shared pointer to the Dataset Ops that this object will create. | |||
| /// \return Status Status::OK() if build successfully. | |||
| Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; | |||
| /// \brief Parameters validation. | |||
| /// \return Status Status::OK() if all the parameters are valid. | |||
| Status ValidateParams() override; | |||
| /// \brief Get the shard id of node. | |||
| /// \param[in] shard_id Id of this shard. | |||
| /// \return Status Status::OK() if get shard id successfully. | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize. | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter. | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset. | |||
| /// \return Status of the function. | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions. | |||
| /// \return Directory of dataset. | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| // \brief Getter functions. | |||
| /// \return The number of samples. | |||
| int32_t NumSamples() const { return num_samples_; } | |||
| // \brief Getter functions. | |||
| /// \return The number of shards. | |||
| int32_t NumShards() const { return num_shards_; } | |||
| // \brief Getter functions. | |||
| /// \return Id of shard. | |||
| int32_t ShardId() const { return shard_id_; } | |||
| // \brief Getter functions. | |||
| /// \return Shuffle pattern. | |||
| ShuffleMode Shuffle() const { return shuffle_; } | |||
| /// \brief Get the arguments of node. | |||
| /// \param[out] out_json JSON string of all attributes. | |||
| /// \return Status of the function. | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief EnWik9 by itself is a non-mappable dataset that does not support sampling. | |||
| /// However, if a cache operator is injected at some other place higher in the tree, that cache can | |||
| /// inherit this sampler from the leaf, providing sampling support from the caching layer. | |||
| /// That is why we setup the sampler for a leaf node that does not use sampling. | |||
| /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. | |||
| /// \param[in] sampler The sampler to setup. | |||
| /// \return Status of the function. | |||
| Status SetupSamplerForCache(std::shared_ptr<SamplerObj> *sampler) override; | |||
| /// \brief If a cache has been added into the ascendant tree over this EnWik9 node, then the cache will be executing. | |||
| /// a sampler for fetching the data. As such, any options in the EnWik9 node need to be reset to its defaults | |||
| /// so that this EnWik9 node will produce the full set of data into the cache. | |||
| /// Note: This function is common among NonMappableSourceNode and should be promoted to its parent class. | |||
| /// \return Status of the function. | |||
| Status MakeSimpleProducer() override; | |||
| /// \brief Change file's directory into file's path, and put it into a list. | |||
| /// \param[in] dataset_dir Directory of enwik9 dataset. | |||
| /// \return A list of read file names. | |||
| void DirToPath(const std::string &dataset_dir); | |||
| private: | |||
| std::string dataset_dir_; // dataset of file. | |||
| int32_t num_samples_; // the number of samples. | |||
| int32_t num_shards_; // the number of shards. | |||
| int32_t shard_id_; // the id of shard. | |||
| ShuffleMode shuffle_; // a object of ShuffleMode, which belongs to num. | |||
| std::vector<std::string> src_target_file_list_; // file list; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EN_WIK9_NODE_H_ | |||
| @@ -2342,6 +2342,51 @@ inline std::shared_ptr<EMnistDataset> MS_APIEMnist(const std::string &dataset_di | |||
| cache); | |||
| } | |||
| /// \class EnWik9Dataset | |||
| /// \brief A source dataset for reading and parsing EnWik9 dataset. | |||
| class MS_API EnWik9Dataset : public Dataset { | |||
| public: | |||
| /// \brief Function to create a EnWik9Dataset. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] num_samples The number of samples to be included in the dataset. | |||
| /// \param[in] shuffle The mode for shuffling data every epoch. | |||
| /// Can be any of: | |||
| /// ShuffleMode.kFalse - No shuffling is performed. | |||
| /// ShuffleMode.kFiles - Shuffle files only. | |||
| /// ShuffleMode.kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into. | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified. | |||
| /// \param[in] cache Tensor cache to use. | |||
| EnWik9Dataset(const std::vector<char> &dataset_dir, int64_t num_samples, ShuffleMode shuffle, int32_t num_shards, | |||
| int32_t shard_id, const std::shared_ptr<DatasetCache> &cache); | |||
| /// Destructor of EnWik9Dataset. | |||
| ~EnWik9Dataset() = default; | |||
| }; | |||
| /// \brief Function to create a EnWik9Dataset. | |||
| /// \note The generated dataset has one column ['text']. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] num_samples The number of samples to be included in the dataset | |||
| /// (Default = 0, means all samples). | |||
| /// \param[in] shuffle The mode for shuffling data every epoch (Default=ShuffleMode.kGlobal). | |||
| /// Can be any of: | |||
| /// ShuffleMode.kFalse - No shuffling is performed. | |||
| /// ShuffleMode.kFiles - Shuffle files only. | |||
| /// ShuffleMode.kGlobal - Shuffle both the files and samples. | |||
| /// \param[in] num_shards Number of shards that the dataset should be divided into (Default = 1). | |||
| /// \param[in] shard_id The shard ID within num_shards. This argument should be | |||
| /// specified only when num_shards is also specified (Default = 0). | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the EnWik9Dataset. | |||
| inline std::shared_ptr<EnWik9Dataset> MS_API EnWik9(const std::string &dataset_dir, int64_t num_samples = 0, | |||
| ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, | |||
| int32_t shard_id = 0, | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<EnWik9Dataset>(StringToChar(dataset_dir), num_samples, shuffle, num_shards, shard_id, cache); | |||
| } | |||
| /// \class FakeImageDataset | |||
| /// \brief A source dataset for generating fake images. | |||
| class MS_API FakeImageDataset : public Dataset { | |||
| @@ -76,7 +76,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||
| check_stl10_dataset, check_yelp_review_dataset, check_penn_treebank_dataset, check_iwslt2016_dataset, \ | |||
| check_iwslt2017_dataset, check_sogou_news_dataset, check_yahoo_answers_dataset, check_udpos_dataset, \ | |||
| check_conll2000_dataset, check_amazon_review_dataset, check_semeion_dataset, check_caltech101_dataset, \ | |||
| check_caltech256_dataset, check_wiki_text_dataset, check_imdb_dataset, check_wider_face_dataset | |||
| check_caltech256_dataset, check_wiki_text_dataset, check_imdb_dataset, check_wider_face_dataset, \ | |||
| check_en_wik9_dataset | |||
| from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ | |||
| get_prefetch_size | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| @@ -10805,6 +10806,82 @@ class STL10Dataset(MappableDataset): | |||
| return cde.STL10Node(self.dataset_dir, self.usage, self.sampler) | |||
| class EnWik9Dataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses EnWik9 dataset. | |||
| The generated dataset has one column :py:obj:`[text]` with type string. | |||
| Args: | |||
| dataset_dir (str): Path to the root directory that contains the dataset. | |||
| num_samples (int, optional): The number of samples to be included in the dataset | |||
| (default=None, will include all samples). | |||
| num_parallel_workers (int, optional): Number of workers to read the data | |||
| (default=None, number set in the config). | |||
| shuffle (Union[bool, Shuffle level], optional): Perform reshuffling of the data every epoch | |||
| (default=True). | |||
| If shuffle is False, no shuffling will be performed; | |||
| If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL | |||
| Otherwise, there are two levels of shuffling: | |||
| - Shuffle.GLOBAL: Shuffle both the files and samples. | |||
| - Shuffle.FILES: Shuffle files only. | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| When this argument is specified, `num_samples` reflects the maximum sample number of per shard. | |||
| shard_id (int, optional): The shard ID within num_shards (default=None). This | |||
| argument can only be specified when num_shards is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing | |||
| (default=None, which means no cache is used). | |||
| Examples: | |||
| >>> en_wik9_dataset_dir = "/path/to/en_wik9_dataset" | |||
| >>> dataset2 = ds.EnWik9Dataset(dataset_dir=en_wik9_dataset_dir, num_samples=2, | |||
| ... shuffle=True) | |||
| About EnWik9 dataset: | |||
| The data of EnWik9 is UTF-8 encoded XML consisting primarily of English text. It contains 243,426 article titles, | |||
| of which 85,560 are #REDIRECT to fix broken links, and the rest are regular articles. | |||
| The data is UTF-8 clean. All characters are in the range U'0000 to U'10FFFF with valid encodings of 1 to | |||
| 4 bytes. The byte values 0xC0, 0xC1, and 0xF5-0xFF never occur. Also, in the Wikipedia dumps, | |||
| there are no control characters in the range 0x00-0x1F except for 0x09 (tab) and 0x0A (linefeed). | |||
| Linebreaks occur only on paragraph boundaries, so they always have a semantic purpose. | |||
| You can unzip the dataset files into the following directory structure and read by MindSpore's API. | |||
| .. code-block:: | |||
| . | |||
| └── EnWik9 | |||
| ├── enwik9 | |||
| Citation: | |||
| .. code-block:: | |||
| @NetworkResource{Hutter_prize, | |||
| author = {English Wikipedia}, | |||
| url = "https://cs.fit.edu/~mmahoney/compression/textdata.html", | |||
| month = {March}, | |||
| year = {2006} | |||
| } | |||
| """ | |||
| @check_en_wik9_dataset | |||
| def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=True, | |||
| num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers=num_parallel_workers, num_samples=num_samples, shuffle=shuffle, | |||
| num_shards=num_shards, shard_id=shard_id, cache=cache) | |||
| self.dataset_dir = dataset_dir | |||
| def parse(self, children=None): | |||
| return cde.EnWik9Node(self.dataset_dir, self.num_samples, self.shuffle_flag, self.num_shards, | |||
| self.shard_id) | |||
| class YahooAnswersDataset(SourceDataset): | |||
| """ | |||
| A source dataset that reads and parses the YahooAnswers dataset. | |||
| @@ -2499,3 +2499,26 @@ def check_wiki_text_dataset(method): | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_en_wik9_dataset(method): | |||
| """Wrapper method to check the parameters of EnWik9 dataset.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||
| dataset_dir = param_dict.get('dataset_dir') | |||
| check_dir(dataset_dir) | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -28,6 +28,7 @@ SET(DE_UT_SRCS | |||
| c_api_dataset_dbpedia_test.cc | |||
| c_api_dataset_div2k_test.cc | |||
| c_api_dataset_emnist_test.cc | |||
| c_api_dataset_en_wik9_test.cc | |||
| c_api_dataset_fake_image_test.cc | |||
| c_api_dataset_fashion_mnist_test.cc | |||
| c_api_dataset_flickr_test.cc | |||
| @@ -0,0 +1,427 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/core/global_context.h" | |||
| #include "minddata/dataset/include/dataset/datasets.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::dataset::ShuffleMode; | |||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | |||
| protected: | |||
| }; | |||
| /// Feature: EnWik9Dataset | |||
| /// Description: test EnWik9Dataset in pipeline mode | |||
| /// Expectation: the number of samples is proper | |||
| TEST_F(MindDataTestPipeline, TestEnWik9DatasetBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetBasic."; | |||
| // Test EnWik9 Dataset with single text file and many default inputs. | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(987); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(4); | |||
| // Create a EnWik9 Dataset, with single enwik9 file. | |||
| // Note: /testEnWik9Dataset/enwik9 has 13 rows. | |||
| // Use 2 samples. | |||
| // Use defaults for other input parameters. | |||
| std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset"; | |||
| std::shared_ptr<Dataset> ds = EnWik9(tf_file, 2); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::string> expected_result = {" <title>MindSpore</title>", " <page>"}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto text = row["text"]; | |||
| MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); | |||
| // Compare against expected result. | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| // Expect 2 samples. | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: EnWik9Dataset | |||
| /// Description: test EnWik9Dataset in pipeline mode | |||
| /// Expectation: the number of samples is proper | |||
| TEST_F(MindDataTestPipeline, TestEnWik9DatasetBasicAndRepeat) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetBasicAndRepeat."; | |||
| // Test EnWik9 Dataset with single enwik9 file and many default inputs. | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(987); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(4); | |||
| // Create two EnWik9 Dataset, with single enwik9 file. | |||
| // Note: /testEnWik9Dataset/enwik9 has 13 rows. | |||
| // Use 2 samples. | |||
| // Use defaults for other input parameters. | |||
| std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset"; | |||
| std::shared_ptr<Dataset> ds1 = EnWik9(tf_file, 2); | |||
| std::shared_ptr<Dataset> ds2 = EnWik9(tf_file, 2); | |||
| EXPECT_NE(ds1, nullptr); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create two Repeat operation on ds. | |||
| int32_t repeat_num = 2; | |||
| ds1 = ds1->Repeat(repeat_num); | |||
| EXPECT_NE(ds1, nullptr); | |||
| repeat_num = 3; | |||
| ds2 = ds2->Repeat(repeat_num); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create a Concat operation on the ds. | |||
| ds1 = ds1->Concat({ds2}); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds1->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::string> expected_result = {" <page>", " <title>MindSpore</title>"}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto text = row["text"]; | |||
| MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| // Expect 10 samples. | |||
| EXPECT_EQ(i, 10); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: EnWik9Dataset | |||
| /// Description: test EnWik9Dataset in pipeline mode | |||
| /// Expectation: the number of samples is proper | |||
| TEST_F(MindDataTestPipeline, TestEnWik9Getters) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9Getters."; | |||
| // Test EnWik9 Dataset with single text file and many default inputs. | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(987); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(4); | |||
| // Create a EnWik9 Dataset, with single enwik9 file. | |||
| // Note: /testEnWik9Dataset/enwik9 has 3 rows. | |||
| // Use 2 samples. | |||
| // Use defaults for other input parameters. | |||
| std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset"; | |||
| std::shared_ptr<Dataset> ds = EnWik9(tf_file, 2); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::vector<std::string> column_names = {"text"}; | |||
| EXPECT_EQ(ds->GetDatasetSize(), 2); | |||
| EXPECT_EQ(ds->GetColumnNames(), column_names); | |||
| ds = EnWik9(tf_file, 0); | |||
| EXPECT_NE(ds, nullptr); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 13); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: EnWik9Dataset | |||
| /// Description: test EnWik9Dataset in pipeline mode | |||
| /// Expectation: the argument named dataset_file is incorrect | |||
| TEST_F(MindDataTestPipeline, TestEnWik9DatasetFailNoExistentPath) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetFailNoExistentPath."; | |||
| // Create a EnWik9 Dataset. | |||
| // with non-existent dataset_files input. | |||
| std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset"; | |||
| std::shared_ptr<Dataset> ds = EnWik9("/NotExist", 0, ShuffleMode::kFalse); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| // Expect failure: specified dataset_files does not exist | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| /// Feature: EnWik9Dataset | |||
| /// Description: test EnWik9Dataset in pipeline mode | |||
| /// Expectation: the data of samples is proper | |||
| TEST_F(MindDataTestPipeline, TestEnWik9DatasetShuffleFalse1A) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetShuffleFalse1A."; | |||
| // Test EnWik9 Dataset with two enwik9 files and no shuffle, num_parallel_workers=1. | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(654); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(1); | |||
| // Create a EnWik9 Dataset, with one enwik9 file, /testEnWik9Dataset/enwik9. | |||
| // Note: /testEnWik9Dataset/enwik9 has 13 rows. | |||
| // Use default of all samples | |||
| std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset"; | |||
| std::shared_ptr<Dataset> ds = EnWik9(tf_file, 0, ShuffleMode::kFalse); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::string> expected_result = {" <page>", | |||
| " <title>MindSpore</title>", | |||
| " <id>1</id>", | |||
| " <revision>", | |||
| " <id>234</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <contributor>", | |||
| " <username>MS</username>", | |||
| " <id>567</id>", | |||
| " </contributor>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " </revision>", | |||
| " </page>"}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto text = row["text"]; | |||
| MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); | |||
| // Compare against expected result. | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| // Expect 13 samples. | |||
| EXPECT_EQ(i, 13); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: EnWik9Dataset | |||
| /// Description: test EnWik9Dataset in pipeline mode | |||
| /// Expectation: the data of samples is proper | |||
| TEST_F(MindDataTestPipeline, TestEnWik9DatasetShuffleFalse4Shard) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetShuffleFalse4Shard."; | |||
| // Test EnWik9 Dataset with one enwik9 files and no shuffle, num_parallel_workers=4, shard coverage. | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(654); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(4); | |||
| // Create a EnWik9 Dataset, with one enwik9 file. | |||
| // Note: /testEnWik9Dataset/enwik9 has 13 rows. | |||
| // Set shuffle to file shuffle, num_shards=2, shard_id=0 | |||
| std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset"; | |||
| std::shared_ptr<Dataset> ds = EnWik9(tf_file, 0, ShuffleMode::kFalse, 2, 0); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::string> expected_result = {" <page>", | |||
| " <title>MindSpore</title>", | |||
| " <id>1</id>", | |||
| " <revision>", | |||
| " <id>234</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <contributor>", | |||
| " <username>MS</username>", | |||
| " <id>567</id>", | |||
| " </contributor>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " </revision>", | |||
| " </page>"}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto text = row["text"]; | |||
| MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); | |||
| // Compare against expected result. | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| // Expect 7 samples for this shard. | |||
| EXPECT_EQ(i, 7); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| /// Feature: EnWik9Dataset | |||
| /// Description: test EnWik9Dataset in pipeline mode | |||
| /// Expectation: the data of samples is proper | |||
| TEST_F(MindDataTestPipeline, TestEnWik9DatasetShuffleGlobal1A) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEnWik9DatasetShuffleGlobal1A."; | |||
| // Test EnWik9 Dataset with one enwik9 file, global shuffle, num_parallel_workers=1. | |||
| // Set configuration. | |||
| uint32_t original_seed = GlobalContext::config_manager()->seed(); | |||
| uint32_t original_num_parallel_workers = GlobalContext::config_manager()->num_parallel_workers(); | |||
| MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers; | |||
| GlobalContext::config_manager()->set_seed(246); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(1); | |||
| // Create a EnWik9 Dataset, with one enwik9 file. | |||
| // Note: /testEnWik9Dataset/enwik9 has 13 rows. | |||
| // Set shuffle to global shuffle. | |||
| std::string tf_file = datasets_root_path_ + "/testEnWik9Dataset"; | |||
| std::shared_ptr<Dataset> ds = EnWik9(tf_file, 0, ShuffleMode::kGlobal); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("text"), row.end()); | |||
| std::vector<std::string> expected_result = {" </contributor>", | |||
| " <page>", | |||
| " <contributor>", | |||
| " <username>MS</username>", | |||
| " <title>MindSpore</title>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " <revision>", | |||
| " <id>567</id>", | |||
| " </revision>", | |||
| " </page>", | |||
| " <id>234</id>", | |||
| " <id>1</id>"}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto text = row["text"]; | |||
| MS_LOG(INFO) << "Tensor text shape: " << text.Shape(); | |||
| std::shared_ptr<Tensor> de_text; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(text, &de_text)); | |||
| std::string_view sv; | |||
| ASSERT_OK(de_text->GetItemAt(&sv, {})); | |||
| std::string ss(sv); | |||
| MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50); | |||
| // Compare against expected result. | |||
| EXPECT_STREQ(ss.c_str(), expected_result[i].c_str()); | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| // Expect 13 samples. | |||
| EXPECT_EQ(i, 13); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| // Restore configuration. | |||
| GlobalContext::config_manager()->set_seed(original_seed); | |||
| GlobalContext::config_manager()->set_num_parallel_workers(original_num_parallel_workers); | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| <page> | |||
| <title>MindSpore</title> | |||
| <id>1</id> | |||
| <revision> | |||
| <id>234</id> | |||
| <timestamp>2020-01-01T00:00:00Z</timestamp> | |||
| <contributor> | |||
| <username>MS</username> | |||
| <id>567</id> | |||
| </contributor> | |||
| <text xml:space="preserve">666</text> | |||
| </revision> | |||
| </page> | |||
| @@ -0,0 +1,296 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| from util import config_get_set_num_parallel_workers, config_get_set_seed | |||
| DATA_FILE = "../data/dataset/testEnWik9Dataset" | |||
| def test_enwik9_total_rows_dataset_num_samples_none(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function while param num_samples = 0 | |||
| Expectation: the number of samples is 13 | |||
| """ | |||
| # Do not provide a num_samples argument, so it would be None by default. | |||
| data = ds.EnWik9Dataset(DATA_FILE) | |||
| count = 0 | |||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| logger.info("{}".format(i["text"])) | |||
| count += 1 | |||
| assert count == 13 | |||
| def test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_two(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function while param shuffle = False | |||
| Expectation: the samples is ordered | |||
| """ | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(2) | |||
| original_seed = config_get_set_seed(987) | |||
| data = ds.EnWik9Dataset(DATA_FILE, shuffle=False) | |||
| count = 0 | |||
| line = [" <page>", | |||
| " <title>MindSpore</title>", | |||
| " <id>1</id>", | |||
| " <revision>", | |||
| " <id>234</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <contributor>", | |||
| " <username>MS</username>", | |||
| " <id>567</id>", | |||
| " </contributor>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " </revision>", | |||
| " </page>"] | |||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| strs = i["text"].item().decode("utf8") | |||
| assert strs == line[count] | |||
| count += 1 | |||
| assert count == 13 | |||
| # Restore configuration. | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| ds.config.set_seed(original_seed) | |||
| def test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_one(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function while param shuffle = False | |||
| Expectation: the samples is ordered | |||
| """ | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| original_seed = config_get_set_seed(987) | |||
| data = ds.EnWik9Dataset(DATA_FILE, shuffle=False) | |||
| count = 0 | |||
| line = [" <page>", | |||
| " <title>MindSpore</title>", | |||
| " <id>1</id>", | |||
| " <revision>", | |||
| " <id>234</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <contributor>", | |||
| " <username>MS</username>", | |||
| " <id>567</id>", | |||
| " </contributor>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " </revision>", | |||
| " </page>"] | |||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| strs = i["text"].item().decode("utf8") | |||
| assert strs == line[count] | |||
| count += 1 | |||
| assert count == 13 | |||
| # Restore configuration. | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| ds.config.set_seed(original_seed) | |||
| def test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_two(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function while param shuffle = True | |||
| Expectation: the samples is disorder | |||
| """ | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(2) | |||
| original_seed = config_get_set_seed(135) | |||
| data = ds.EnWik9Dataset(DATA_FILE, shuffle=True) | |||
| count = 0 | |||
| line = [" <username>MS</username>", | |||
| " <title>MindSpore</title>", | |||
| " <id>234</id>", | |||
| " </revision>", | |||
| " </contributor>", | |||
| " <revision>", | |||
| " <id>567</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <id>1</id>", | |||
| " </page>", | |||
| " <page>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " <contributor>"] | |||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| strs = i["text"].item().decode("utf8") | |||
| assert strs == line[count] | |||
| count += 1 | |||
| assert count == 13 | |||
| # Restore configuration. | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| ds.config.set_seed(original_seed) | |||
| def test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_one(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function while param shuffle = True | |||
| Expectation: the samples is disorder | |||
| """ | |||
| original_num_parallel_workers = config_get_set_num_parallel_workers(1) | |||
| original_seed = config_get_set_seed(135) | |||
| data = ds.EnWik9Dataset(DATA_FILE, shuffle=True) | |||
| count = 0 | |||
| line = [" <username>MS</username>", | |||
| " <title>MindSpore</title>", | |||
| " <id>234</id>", | |||
| " </revision>", | |||
| " </contributor>", | |||
| " <revision>", | |||
| " <id>567</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <id>1</id>", | |||
| " </page>", | |||
| " <page>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " <contributor>"] | |||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| strs = i["text"].item().decode("utf8") | |||
| assert strs == line[count] | |||
| count += 1 | |||
| assert count == 13 | |||
| # Restore configuration. | |||
| ds.config.set_num_parallel_workers(original_num_parallel_workers) | |||
| ds.config.set_seed(original_seed) | |||
| def test_enwik9_dataset_num_samples(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test param num_samples, while it = 2 | |||
| Expectation: the number of samples = 2 | |||
| """ | |||
| data = ds.EnWik9Dataset(DATA_FILE, num_samples=2) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count += 1 | |||
| assert count == 2 | |||
| def test_enwik9_dataset_distribution(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test distribution of the dataset | |||
| Expectation: count = 7 | |||
| """ | |||
| data = ds.EnWik9Dataset(DATA_FILE, num_shards=2, shard_id=1) | |||
| count = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| count += 1 | |||
| assert count == 7 | |||
| def test_enwik9_total_rows_dataset_repeat(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function whie the samples are repeat | |||
| Expectation: count = 26 | |||
| """ | |||
| data = ds.EnWik9Dataset(DATA_FILE, shuffle=False) | |||
| data = data.repeat(2) | |||
| count = 0 | |||
| line = [" <page>", | |||
| " <title>MindSpore</title>", | |||
| " <id>1</id>", | |||
| " <revision>", | |||
| " <id>234</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <contributor>", | |||
| " <username>MS</username>", | |||
| " <id>567</id>", | |||
| " </contributor>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " </revision>", | |||
| " </page>", | |||
| " <page>", | |||
| " <title>MindSpore</title>", | |||
| " <id>1</id>", | |||
| " <revision>", | |||
| " <id>234</id>", | |||
| " <timestamp>2020-01-01T00:00:00Z</timestamp>", | |||
| " <contributor>", | |||
| " <username>MS</username>", | |||
| " <id>567</id>", | |||
| " </contributor>", | |||
| " <text xml:space=\"preserve\">666</text>", | |||
| " </revision>", | |||
| " </page>"] | |||
| for i in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| strs = i["text"].item().decode("utf8") | |||
| assert strs == line[count] | |||
| count += 1 | |||
| assert count == 26 | |||
| def test_enwik9_total_rows_dataset_get_datasetsize(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function, get_dataset_size() | |||
| Expectation: size = 13 | |||
| """ | |||
| data = ds.EnWik9Dataset(DATA_FILE) | |||
| size = data.get_dataset_size() | |||
| assert size == 13 | |||
| def test_enwik9_total_rows_dataset_to_device(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the function, to_device() | |||
| Expectation: size = 13 | |||
| """ | |||
| data = ds.EnWik9Dataset(DATA_FILE, shuffle=False) | |||
| data = data.to_device() | |||
| data.send() | |||
| def test_enwik9_dataset_exceptions(): | |||
| """ | |||
| Feature: EnWik9Dataset | |||
| Description: test the errors which appear possibly | |||
| Expectation: the errors are expected correctly | |||
| """ | |||
| with pytest.raises(ValueError) as error_info: | |||
| _ = ds.EnWik9Dataset("does/not/exist/") | |||
| assert "does not exist or is not a directory or permission denied" in str(error_info.value) | |||
| with pytest.raises(ValueError) as error_info: | |||
| _ = ds.EnWik9Dataset("") | |||
| assert "The folder does not exist or is not a directory or permission denied" in str(error_info.value) | |||
| def exception_func(item): | |||
| raise Exception("Error occur!") | |||
| with pytest.raises(RuntimeError) as error_info: | |||
| data = ds.EnWik9Dataset(DATA_FILE) | |||
| data = data.map(operations=exception_func, input_columns=["text"], num_parallel_workers=1) | |||
| for _ in data.__iter__(): | |||
| pass | |||
| assert "map operation: [PyFunc] failed. The corresponding data files" in str(error_info.value) | |||
| if __name__ == "__main__": | |||
| test_enwik9_total_rows_dataset_num_samples_none() | |||
| test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_two() | |||
| test_enwik9_total_rows_dataset_shuffle_false_parallel_worker_one() | |||
| test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_two() | |||
| test_enwik9_total_rows_dataset_shuffle_true_parallel_worker_one() | |||
| test_enwik9_dataset_num_samples() | |||
| test_enwik9_dataset_distribution() | |||
| test_enwik9_total_rows_dataset_repeat() | |||
| test_enwik9_total_rows_dataset_get_datasetsize() | |||
| test_enwik9_total_rows_dataset_to_device() | |||
| test_enwik9_dataset_exceptions() | |||