Merge pull request !21648 from 杨旭华/YesNoDatasettags/v1.6.0
| @@ -115,6 +115,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/tf_record_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/usps_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/voc_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -1543,6 +1544,27 @@ TFRecordDataset::TFRecordDataset(const std::vector<std::vector<char>> &dataset_f | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| auto sampler_obj = sampler ? sampler->Parse() : nullptr; | |||
| auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *sampler, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| auto sampler_obj = sampler ? sampler->Parse() : nullptr; | |||
| auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| YesNoDataset::YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| auto sampler_obj = sampler.get().Parse(); | |||
| auto ds = std::make_shared<YesNoNode>(CharToString(dataset_dir), sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| #endif | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -44,6 +44,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/random_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/text_file_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h" | |||
| // IR leaf nodes disabled for android | |||
| #ifndef ENABLE_ANDROID | |||
| @@ -448,5 +449,15 @@ PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) { | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(YesNoNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<YesNoNode, DatasetNode, std::shared_ptr<YesNoNode>>(*m, "YesNoNode", | |||
| "to create a YesNoNode") | |||
| .def(py::init([](std::string dataset_dir, py::handle sampler) { | |||
| auto yes_no = std::make_shared<YesNoNode>(dataset_dir, toSamplerObj(sampler), nullptr); | |||
| THROW_IF_ERROR(yes_no->ValidateParams()); | |||
| return yes_no; | |||
| })); | |||
| })); | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -3,33 +3,34 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" | |||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | |||
| set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| io_block.cc | |||
| image_folder_op.cc | |||
| mnist_op.cc | |||
| coco_op.cc | |||
| cifar_op.cc | |||
| random_data_op.cc | |||
| ag_news_op.cc | |||
| album_op.cc | |||
| celeba_op.cc | |||
| sbu_op.cc | |||
| text_file_op.cc | |||
| cifar_op.cc | |||
| cityscapes_op.cc | |||
| clue_op.cc | |||
| coco_op.cc | |||
| csv_op.cc | |||
| album_op.cc | |||
| usps_op.cc | |||
| mappable_leaf_op.cc | |||
| nonmappable_leaf_op.cc | |||
| cityscapes_op.cc | |||
| dbpedia_op.cc | |||
| div2k_op.cc | |||
| flickr_op.cc | |||
| qmnist_op.cc | |||
| emnist_op.cc | |||
| fake_image_op.cc | |||
| fashion_mnist_op.cc | |||
| flickr_op.cc | |||
| image_folder_op.cc | |||
| io_block.cc | |||
| lj_speech_op.cc | |||
| places365_op.cc | |||
| mappable_leaf_op.cc | |||
| mnist_op.cc | |||
| nonmappable_leaf_op.cc | |||
| photo_tour_op.cc | |||
| fashion_mnist_op.cc | |||
| ag_news_op.cc | |||
| dbpedia_op.cc | |||
| places365_op.cc | |||
| qmnist_op.cc | |||
| random_data_op.cc | |||
| sbu_op.cc | |||
| text_file_op.cc | |||
| usps_op.cc | |||
| yes_no_op.cc | |||
| ) | |||
| set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| @@ -0,0 +1,148 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/datasetops/source/yes_no_op.h" | |||
| #include <algorithm> | |||
| #include <fstream> | |||
| #include <iomanip> | |||
| #include <regex> | |||
| #include <set> | |||
| #include "minddata/dataset/audio/kernels/audio_utils.h" | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/core/tensor_shape.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "utils/file_utils.h" | |||
| #include "utils/ms_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| constexpr float kMaxShortVal = 32767.0; | |||
| constexpr char kExtension[] = ".wav"; | |||
| constexpr int kStrLen = 15; // the length of name. | |||
| #ifndef _WIN32 | |||
| constexpr char kSplitSymbol[] = "/"; | |||
| #else | |||
| constexpr char kSplitSymbol[] = "\\"; | |||
| #endif | |||
| YesNoOp::YesNoOp(const std::string &file_dir, int32_t num_workers, int32_t queue_size, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||
| : MappableLeafOp(num_workers, queue_size, std::move(sampler)), | |||
| dataset_dir_(file_dir), | |||
| data_schema_(std::move(data_schema)) {} | |||
| Status YesNoOp::PrepareData() { | |||
| auto realpath = FileUtils::GetRealPath(dataset_dir_.data()); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Get real path failed, path=" << dataset_dir_; | |||
| RETURN_STATUS_UNEXPECTED("Get real path failed, path=" + dataset_dir_); | |||
| } | |||
| Path dir(realpath.value()); | |||
| if (dir.Exists() == false || dir.IsDirectory() == false) { | |||
| RETURN_STATUS_UNEXPECTED("Invalid parameter, failed to open speech commands: " + dataset_dir_); | |||
| } | |||
| std::shared_ptr<Path::DirIterator> dir_itr = Path::DirIterator::OpenDirectory(&dir); | |||
| RETURN_UNEXPECTED_IF_NULL(dir_itr); | |||
| while (dir_itr->HasNext()) { | |||
| Path file = dir_itr->Next(); | |||
| if (file.Extension() == kExtension) { | |||
| all_wave_files_.emplace_back(file.ToString()); | |||
| } | |||
| } | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!all_wave_files_.empty(), "Invalid file, no .wav files found under " + dataset_dir_); | |||
| num_rows_ = all_wave_files_.size(); | |||
| return Status::OK(); | |||
| } | |||
| void YesNoOp::Print(std::ostream &out, bool show_all) const { | |||
| if (!show_all) { | |||
| ParallelOp::Print(out, show_all); | |||
| out << "\n"; | |||
| } else { | |||
| ParallelOp::Print(out, show_all); | |||
| out << "\nNumber of rows: " << num_rows_ << "\nYesNo directory: " << dataset_dir_ << "\n\n"; | |||
| } | |||
| } | |||
| Status YesNoOp::Split(const std::string &line, std::vector<int32_t> *split_num) { | |||
| RETURN_UNEXPECTED_IF_NULL(split_num); | |||
| std::string str = line; | |||
| int dot_pos = str.find_last_of(kSplitSymbol); | |||
| std::string sub_line = line.substr(dot_pos + 1, kStrLen); // (dot_pos + 1) because the index start from 0. | |||
| std::string::size_type pos; | |||
| std::vector<std::string> split; | |||
| sub_line += "_"; // append to sub_line indicating the end of the string. | |||
| uint32_t size = sub_line.size(); | |||
| for (uint32_t index = 0; index < size;) { | |||
| pos = sub_line.find("_", index); | |||
| if (pos != index) { | |||
| std::string s = sub_line.substr(index, pos - index); | |||
| split.emplace_back(s); | |||
| } | |||
| index = pos + 1; | |||
| } | |||
| try { | |||
| for (int i = 0; i < split.size(); i++) { | |||
| split_num->emplace_back(stoi(split[i])); | |||
| } | |||
| } catch (const std::exception &e) { | |||
| MS_LOG(ERROR) << "Converting char to int confront with an error in function stoi()."; | |||
| RETURN_STATUS_UNEXPECTED("Converting char to int confront with an error in function stoi()."); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status YesNoOp::LoadTensorRow(row_id_type index, TensorRow *trow) { | |||
| RETURN_UNEXPECTED_IF_NULL(trow); | |||
| std::shared_ptr<Tensor> waveform, sample_rate_scalar, label_scalar; | |||
| int32_t sample_rate; | |||
| std::string file_name = all_wave_files_[index]; | |||
| std::vector<int32_t> label; | |||
| std::vector<float> waveform_vec; | |||
| RETURN_IF_NOT_OK(Split(file_name, &label)); | |||
| RETURN_IF_NOT_OK(ReadWaveFile(file_name, &waveform_vec, &sample_rate)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, &waveform)); | |||
| RETURN_IF_NOT_OK(waveform->ExpandDim(0)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &sample_rate_scalar)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromVector(label, &label_scalar)); | |||
| (*trow) = TensorRow(index, {waveform, sample_rate_scalar, label_scalar}); | |||
| trow->setPath({file_name, file_name, file_name}); | |||
| return Status::OK(); | |||
| } | |||
| Status YesNoOp::CountTotalRows(int64_t *count) { | |||
| RETURN_UNEXPECTED_IF_NULL(count); | |||
| if (all_wave_files_.size() == 0) { | |||
| RETURN_IF_NOT_OK(PrepareData()); | |||
| } | |||
| *count = static_cast<int64_t>(all_wave_files_.size()); | |||
| return Status::OK(); | |||
| } | |||
| Status YesNoOp::ComputeColMap() { | |||
| if (column_name_id_map_.empty()) { | |||
| for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { | |||
| column_name_id_map_[data_schema_->Column(i).Name()] = i; | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Column name map is already set!"; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_YES_NO_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_YES_NO_OP_H_ | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/services.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/util/wait_post.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class YesNoOp : public MappableLeafOp { | |||
| public: | |||
| /// Constructor. | |||
| /// @param std::string file_dir - dir directory of YesNo. | |||
| /// @param int32_t num_workers - number of workers reading images in parallel. | |||
| /// @param int32_t queue_size - connector queue size. | |||
| /// @param std::unique_ptr<DataSchema> data_schema - the schema of the YesNo dataset. | |||
| /// @param std::shared_ptr<Sampler> sampler - sampler tells YesNoOp what to read. | |||
| YesNoOp(const std::string &file_dir, int32_t num_workers, int32_t queue_size, std::unique_ptr<DataSchema> data_schema, | |||
| std::shared_ptr<SamplerRT> sampler); | |||
| /// Destructor. | |||
| ~YesNoOp() = default; | |||
| /// A print method typically used for debugging. | |||
| /// @param std::ostream &out - out stream. | |||
| /// @param bool show_all - whether to show all information. | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| /// Op name getter. | |||
| /// @return Name of the current Op. | |||
| std::string Name() const override { return "YesNoOp"; } | |||
| /// @param int64_t *count - output rows number of YesNoDataset. | |||
| /// @return Status - The status code returned. | |||
| Status CountTotalRows(int64_t *count); | |||
| private: | |||
| /// Load a tensor row according to wave id. | |||
| /// @param row_id_type row_id - id for this tensor row. | |||
| /// @param TensorRow trow - wave & target read into this tensor row. | |||
| /// @return Status - The status code returned. | |||
| Status LoadTensorRow(row_id_type row_id, TensorRow *trow) override; | |||
| /// Get file infos by file name. | |||
| /// @param string line - file name. | |||
| /// @param vector split_num - vector of annotation. | |||
| /// @return Status - The status code returned. | |||
| Status Split(const std::string &line, std::vector<int32_t> *split_num); | |||
| /// Initialize YesNoDataset related var, calls the function to walk all files. | |||
| /// @return Status - The status code returned. | |||
| Status PrepareData(); | |||
| /// Private function for computing the assignment of the column name map. | |||
| /// @return Status - The status code returned. | |||
| Status ComputeColMap() override; | |||
| std::vector<std::string> all_wave_files_; | |||
| std::string dataset_dir_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_YES_NO_OP_H | |||
| @@ -104,6 +104,7 @@ constexpr char kTextFileNode[] = "TextFileDataset"; | |||
| constexpr char kTFRecordNode[] = "TFRecordDataset"; | |||
| constexpr char kUSPSNode[] = "USPSDataset"; | |||
| constexpr char kVOCNode[] = "VOCDataset"; | |||
| constexpr char kYesNoNode[] = "YesNoDataset"; | |||
| Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, | |||
| int32_t connector_que_size, std::shared_ptr<DatasetOp> *shuffle_op); | |||
| @@ -32,6 +32,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | |||
| tf_record_node.cc | |||
| usps_node.cc | |||
| voc_node.cc | |||
| yes_no_node.cc | |||
| ) | |||
| if(ENABLE_PYTHON) | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/yes_no_node.h" | |||
| #include <utility> | |||
| #include "minddata/dataset/engine/datasetops/source/yes_no_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| // Constructor for YesNoNode. | |||
| YesNoNode::YesNoNode(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), sampler_(sampler) {} | |||
| std::shared_ptr<DatasetNode> YesNoNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); | |||
| auto node = std::make_shared<YesNoNode>(dataset_dir_, sampler, cache_); | |||
| return node; | |||
| } | |||
| void YesNoNode::Print(std::ostream &out) const { | |||
| out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")"); | |||
| } | |||
| Status YesNoNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("YesNoNode", dataset_dir_)); | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("YesNoNode", sampler_)); | |||
| return Status::OK(); | |||
| } | |||
| Status YesNoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| // Do internal Schema generation. | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kFlexible, 1))); | |||
| TensorShape sample_rate_scalar = TensorShape::CreateScalar(); | |||
| TensorShape lable_scalar = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn( | |||
| ColDescriptor("sample_rate", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &sample_rate_scalar))); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &lable_scalar))); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto op = std::make_shared<YesNoOp>(dataset_dir_, num_workers_, connector_que_size_, std::move(schema), | |||
| std::move(sampler_rt)); | |||
| op->SetTotalRepeats(GetTotalRepeats()); | |||
| op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| Status YesNoNode::GetShardId(int32_t *shard_id) { | |||
| *shard_id = sampler_->ShardId(); | |||
| return Status::OK(); | |||
| } | |||
| Status YesNoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| std::vector<std::shared_ptr<DatasetOp>> ops; | |||
| RETURN_IF_NOT_OK(Build(&ops)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(!ops.empty(), "Unable to build YesNoOp."); | |||
| auto op = std::dynamic_pointer_cast<YesNoOp>(ops.front()); | |||
| RETURN_IF_NOT_OK(op->CountTotalRows(&num_rows)); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| if (sample_size == -1) { | |||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||
| } | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| Status YesNoNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,92 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_YES_NO_NODE_H | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_YES_NO_NODE_H | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class YesNoNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor. | |||
| YesNoNode(const std::string &dataset_dir, std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache); | |||
| /// \brief Destructor. | |||
| ~YesNoNode() = default; | |||
| /// \brief Node name getter. | |||
| /// \return Name of the current node. | |||
| std::string Name() const override { return "YesNoNode"; } | |||
| /// \brief Print the description. | |||
| /// \param out - The output stream to write output to. | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object. | |||
| /// \return A shared pointer to the new copy. | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class. | |||
| /// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create. | |||
| /// \return Status Status::OK() if build successfully. | |||
| Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; | |||
| /// \brief Parameters validation. | |||
| /// \return Status Status::OK() if all the parameters are valid. | |||
| Status ValidateParams() override; | |||
| /// \brief Get the shard id of node. | |||
| /// \param[in] shard_id Shard id. | |||
| /// \return Status Status::OK() if get shard id successfully. | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize. | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter. | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset. | |||
| /// \return Status of the function. | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions. | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| /// \brief Get the arguments of node. | |||
| /// \param[out] out_json JSON string of all attributes. | |||
| /// \return Status of the function. | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Sampler getter. | |||
| /// \return SamplerObj of the current node. | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| /// \brief Sampler setter. | |||
| /// \param[in] sampler Sampler object used to choose samples from the dataset. | |||
| void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; } | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_YES_NO_NODE_H | |||
| @@ -3919,6 +3919,72 @@ inline std::shared_ptr<VOCDataset> VOC(const std::string &dataset_dir, const std | |||
| MapStringToChar(class_indexing), decode, sampler, cache, extra_metadata); | |||
| } | |||
| /// \class YesNoDataset. | |||
| /// \brief A source dataset for reading and parsing YesNo dataset. | |||
| class YesNoDataset : public Dataset { | |||
| public: | |||
| /// \brief Constructor of YesNoDataset. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not | |||
| /// given, a `RandomSampler` will be used to randomly iterate the entire dataset. | |||
| /// \param[in] cache Tensor cache to use. | |||
| YesNoDataset(const std::vector<char> &dataset_dir, const std::shared_ptr<Sampler> &sampler, | |||
| const std::shared_ptr<DatasetCache> &cache); | |||
| /// \brief Constructor of YesNoDataset. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use. | |||
| YesNoDataset(const std::vector<char> &dataset_dir, const Sampler *sampler, | |||
| const std::shared_ptr<DatasetCache> &cache); | |||
| /// \brief Constructor of YesNoDataset. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] sampler Sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use. | |||
| YesNoDataset(const std::vector<char> &dataset_dir, const std::reference_wrapper<Sampler> sampler, | |||
| const std::shared_ptr<DatasetCache> &cache); | |||
| /// Destructor of YesNoDataset. | |||
| ~YesNoDataset() = default; | |||
| }; | |||
| /// \brief Function to create a YesNo Dataset. | |||
| /// \note The generated dataset has three columns ["waveform", "sample_rate", "label"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not | |||
| /// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()). | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the current Dataset. | |||
| inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir, | |||
| const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(), | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache); | |||
| } | |||
| /// \brief Function to create a YesNo Dataset. | |||
| /// \note The generated dataset has three columns ["waveform", "sample_rate", "label"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the current Dataset. | |||
| inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir, Sampler *sampler, | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache); | |||
| } | |||
| /// \brief Function to create a YesNo Dataset. | |||
| /// \note The generated dataset has three columns ["waveform", "sample_rate", "label"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] sampler Sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the current Dataset. | |||
| inline std::shared_ptr<YesNoDataset> YesNo(const std::string &dataset_dir, | |||
| const std::reference_wrapper<Sampler> sampler, | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<YesNoDataset>(StringToChar(dataset_dir), sampler, cache); | |||
| } | |||
| /// \brief Function to create a cache to be attached to a dataset. | |||
| /// \note The reason for providing this API is that std::string will be constrained by the | |||
| /// compiler option '_GLIBCXX_USE_CXX11_ABI' while char is free of this restriction. | |||
| @@ -57,6 +57,7 @@ class Sampler : std::enable_shared_from_this<Sampler> { | |||
| friend class TFRecordDataset; | |||
| friend class USPSDataset; | |||
| friend class VOCDataset; | |||
| friend class YesNoDataset; | |||
| friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t); | |||
| public: | |||
| @@ -68,7 +68,8 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che | |||
| check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \ | |||
| check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \ | |||
| check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset, check_places365_dataset, \ | |||
| check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset | |||
| check_photo_tour_dataset, check_ag_news_dataset, check_dbpedia_dataset, check_lj_speech_dataset, \ | |||
| check_yes_no_dataset | |||
| from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \ | |||
| get_prefetch_size, get_auto_offload | |||
| from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist | |||
| @@ -8362,3 +8363,116 @@ class DIV2KDataset(MappableDataset): | |||
| def parse(self, children=None): | |||
| return cde.DIV2KNode(self.dataset_dir, self.usage, self.downgrade, self.scale, self.decode, self.sampler) | |||
| class YesNoDataset(MappableDataset): | |||
| """ | |||
| A source dataset for reading and parsing the YesNo dataset. | |||
| The generated dataset has three columns :py:obj:`[waveform, sample_rate, labels]`. | |||
| The tensor of column :py:obj:`waveform` is a vector of the float32 type. | |||
| The tensor of column :py:obj:`sample_rate` is a scalar of the int32 type. | |||
| The tensor of column :py:obj:`labels` is a scalar of the int32 type. | |||
| Args: | |||
| dataset_dir (str): Path to the root directory that contains the dataset. | |||
| num_samples (int, optional): The number of images to be included in the dataset | |||
| (default=None, will read all images). | |||
| num_parallel_workers (int, optional): Number of workers to read the data | |||
| (default=None, will use value set in the config). | |||
| shuffle (bool, optional): Whether or not to perform shuffle on the dataset | |||
| (default=None, expected order behavior shown in the table). | |||
| sampler (Sampler, optional): Object used to choose samples from the | |||
| dataset (default=None, expected order behavior shown in the table). | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| When this argument is specified, `num_samples` reflects the maximum sample number of per shard. | |||
| shard_id (int, optional): The shard ID within `num_shards` (default=None). This argument can only | |||
| be specified when `num_shards` is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If dataset_dir does not contain data files. | |||
| RuntimeError: If num_parallel_workers exceeds the max thread numbers. | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| RuntimeError: If sampler and sharding are specified at the same time. | |||
| RuntimeError: If num_shards is specified but shard_id is None. | |||
| RuntimeError: If shard_id is specified but num_shards is None. | |||
| ValueError: If shard_id is invalid (< 0 or >= num_shards). | |||
| Note: | |||
| - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. | |||
| The table below shows what input arguments are allowed and their expected behavior. | |||
| .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle` | |||
| :widths: 25 25 50 | |||
| :header-rows: 1 | |||
| * - Parameter `sampler` | |||
| - Parameter `shuffle` | |||
| - Expected Order Behavior | |||
| * - None | |||
| - None | |||
| - random order | |||
| * - None | |||
| - True | |||
| - random order | |||
| * - None | |||
| - False | |||
| - sequential order | |||
| * - Sampler object | |||
| - None | |||
| - order defined by sampler | |||
| * - Sampler object | |||
| - True | |||
| - not allowed | |||
| * - Sampler object | |||
| - False | |||
| - not allowed | |||
| Examples: | |||
| >>> yes_no_dataset_dir = "/path/to/yes_no_dataset_directory" | |||
| >>> | |||
| >>> # Read 3 samples from YesNo dataset | |||
| >>> dataset = ds.YesNoDataset(dataset_dir=yes_no_dataset_dir, num_samples=3) | |||
| >>> | |||
| >>> # Note: In YesNo dataset, each dictionary has keys "waveform", "sample_rate", "label" | |||
| About YesNo dataset: | |||
| Yesno is an audio dataset consisting of 60 recordings of one individual saying yes or no in Hebrew; each | |||
| recording is eight words long. It was created for the Kaldi audio project by an author who wishes to | |||
| remain anonymous. | |||
| Here is the original YesNo dataset structure. | |||
| You can unzip the dataset files into this directory structure and read by MindSpore's API. | |||
| .. code-block:: | |||
| . | |||
| └── yes_no_dataset_dir | |||
| ├── 1_1_0_0_1_1_0_0.wav | |||
| ├── 1_0_0_0_1_1_0_0.wav | |||
| ├── 1_1_0_0_1_1_0_0.wav | |||
| └──.... | |||
| Citation: | |||
| .. code-block:: | |||
| @NetworkResource{Kaldi_audio_project, | |||
| author = {anonymous}, | |||
| url = "http://wwww.openslr.org/1/" | |||
| } | |||
| """ | |||
| @check_yes_no_dataset | |||
| def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, shuffle=None, | |||
| sampler=None, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, | |||
| shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) | |||
| self.dataset_dir = dataset_dir | |||
| def parse(self, children=None): | |||
| return cde.YesNoNode(self.dataset_dir, self.sampler) | |||
| @@ -1807,3 +1807,29 @@ def check_dbpedia_dataset(method): | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_yes_no_dataset(method): | |||
| """A wrapper that wraps a parameter checker around the original Dataset(YesNoDataset).""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||
| nreq_param_bool = ['shuffle'] | |||
| dataset_dir = param_dict.get('dataset_dir') | |||
| check_dir(dataset_dir) | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| @@ -43,6 +43,7 @@ SET(DE_UT_SRCS | |||
| c_api_dataset_tfrecord_test.cc | |||
| c_api_dataset_usps_test.cc | |||
| c_api_dataset_voc_test.cc | |||
| c_api_dataset_yes_no_test.cc | |||
| c_api_datasets_test.cc | |||
| c_api_epoch_ctrl_test.cc | |||
| c_api_pull_based_test.cc | |||
| @@ -0,0 +1,196 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "minddata/dataset/include/dataset/datasets.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::dataset::DataType; | |||
| using mindspore::dataset::Tensor; | |||
| using mindspore::dataset::TensorShape; | |||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | |||
| protected: | |||
| }; | |||
| /// Feature: Test YesNo dataset. | |||
| /// Description: read data from a single file. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestYesNoDataset) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoDataset."; | |||
| // Create a YesNoDataset | |||
| std::string folder_path = datasets_root_path_ + "/testYesNoData/"; | |||
| std::shared_ptr<Dataset> ds = YesNo(folder_path, std::make_shared<RandomSampler>(false, 2)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| MS_LOG(INFO) << "iter->GetNextRow(&row) OK"; | |||
| EXPECT_NE(row.find("waveform"), row.end()); | |||
| EXPECT_NE(row.find("sample_rate"), row.end()); | |||
| EXPECT_NE(row.find("label"), row.end()); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto waveform = row["waveform"]; | |||
| MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| EXPECT_EQ(i, 2); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test YesNo dataset. | |||
| /// Description: test YesNo dataset with pipeline. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, YesNoDatasetWithPipeline) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-YesNoDatasetWithPipeline."; | |||
| std::string folder_path = datasets_root_path_ + "/testYesNoData/"; | |||
| std::shared_ptr<Dataset> ds1 = YesNo(folder_path, std::make_shared<RandomSampler>(false, 1)); | |||
| std::shared_ptr<Dataset> ds2 = YesNo(folder_path, std::make_shared<RandomSampler>(false, 2)); | |||
| EXPECT_NE(ds1, nullptr); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create two Repeat operation on ds | |||
| int32_t repeat_num = 1; | |||
| ds1 = ds1->Repeat(repeat_num); | |||
| EXPECT_NE(ds1, nullptr); | |||
| repeat_num = 2; | |||
| ds2 = ds2->Repeat(repeat_num); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create two Project operation on ds | |||
| std::vector<std::string> column_project = {"waveform", "sample_rate", "label"}; | |||
| ds1 = ds1->Project(column_project); | |||
| EXPECT_NE(ds1, nullptr); | |||
| ds2 = ds2->Project(column_project); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create a Concat operation on the ds | |||
| ds1 = ds1->Concat({ds2}); | |||
| EXPECT_NE(ds1, nullptr); | |||
| std::shared_ptr<Iterator> iter = ds1->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| EXPECT_NE(row.find("waveform"), row.end()); | |||
| EXPECT_NE(row.find("sample_rate"), row.end()); | |||
| EXPECT_NE(row.find("label"), row.end()); | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto waveform = row["waveform"]; | |||
| MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| EXPECT_EQ(i, 5); | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: Test YesNo dataset. | |||
| /// Description: get the size of YesNo dataset. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestYesNoGetDatasetSize) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoGetDatasetSize."; | |||
| // Create a YesNo Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testYesNoData/"; | |||
| std::shared_ptr<Dataset> ds = YesNo(folder_path); | |||
| EXPECT_NE(ds, nullptr); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 3); | |||
| } | |||
| /// Feature: Test YesNo dataset. | |||
| /// Description: getter functions. | |||
| /// Expectation: the data is processed successfully. | |||
| TEST_F(MindDataTestPipeline, TestYesNoGetters) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoMixGetter."; | |||
| // Create a YesNo Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testYesNoData/"; | |||
| std::shared_ptr<Dataset> ds = YesNo(folder_path); | |||
| EXPECT_NE(ds, nullptr); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 3); | |||
| std::vector<DataType> types = ToDETypes(ds->GetOutputTypes()); | |||
| std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes()); | |||
| std::vector<std::string> column_names = {"waveform", "sample_rate", "label"}; | |||
| EXPECT_EQ(types.size(), 3); | |||
| EXPECT_EQ(types[0].ToString(), "float32"); | |||
| EXPECT_EQ(types[1].ToString(), "int32"); | |||
| EXPECT_EQ(types[2].ToString(), "int32"); | |||
| EXPECT_EQ(shapes.size(), 3); | |||
| EXPECT_EQ(shapes[1].ToString(), "<>"); | |||
| EXPECT_EQ(shapes[2].ToString(), "<8>"); | |||
| EXPECT_EQ(ds->GetBatchSize(), 1); | |||
| EXPECT_EQ(ds->GetRepeatCount(), 1); | |||
| EXPECT_EQ(ds->GetDatasetSize(), 3); | |||
| EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types); | |||
| EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes); | |||
| EXPECT_EQ(ds->GetColumnNames(), column_names); | |||
| } | |||
| /// Feature: Test YesNo dataset. | |||
| /// Description: DatasetFail tests. | |||
| /// Expectation: throw error messages when certain errors occur. | |||
| TEST_F(MindDataTestPipeline, TestYesNoDatasetFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNoDatasetFail."; | |||
| // Create a YesNo Dataset | |||
| std::shared_ptr<Dataset> ds = YesNo("", std::make_shared<RandomSampler>(false, 1)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| // Expect failure: Invalid YesNo directory | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| /// Feature: Test YesNo dataset. | |||
| /// Description: NullSamplerFail tests. | |||
| /// Expectation: Throw error messages when certain errors occur. | |||
| TEST_F(MindDataTestPipeline, TestYesNoDatasetWithNullSamplerFail) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestYesNo10DatasetWithNullSamplerFail."; | |||
| // Create a YesNo Dataset | |||
| std::string folder_path = datasets_root_path_ + "/testYesNoData/"; | |||
| std::shared_ptr<Dataset> ds = YesNo(folder_path, nullptr); | |||
| // Expect failure: Null Sampler | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| // Expect failure: Null Sampler | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| @@ -0,0 +1,185 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.audio.transforms as audio | |||
| from mindspore import log as logger | |||
| DATA_DIR = "../data/dataset/testYesNoData/" | |||
| def test_yes_no_basic(): | |||
| """ | |||
| Feature: YesNo Dataset | |||
| Description: Read all files | |||
| Expectation: Output the amount of file | |||
| """ | |||
| logger.info("Test YesNoDataset Op") | |||
| data = ds.YesNoDataset(DATA_DIR) | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| assert num_iter == 3 | |||
| def test_yes_no_num_samples(): | |||
| """ | |||
| Feature: YesNo Dataset | |||
| Description: Test num_samples | |||
| Expectation: Get certain number of samples | |||
| """ | |||
| data = ds.YesNoDataset(DATA_DIR, num_samples=2) | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| assert num_iter == 2 | |||
| def test_yes_no_repeat(): | |||
| """ | |||
| Feature: YesNo Dataset | |||
| Description: Test repeat | |||
| Expectation: Output the amount of file | |||
| """ | |||
| data = ds.YesNoDataset(DATA_DIR, num_samples=2) | |||
| data = data.repeat(5) | |||
| num_iter = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1): | |||
| num_iter += 1 | |||
| assert num_iter == 10 | |||
| def test_yes_no_dataset_size(): | |||
| """ | |||
| Feature: YesNo Dataset | |||
| Description: Test dataset_size | |||
| Expectation: Output the size of dataset | |||
| """ | |||
| data = ds.YesNoDataset(DATA_DIR, shuffle=False) | |||
| assert data.get_dataset_size() == 3 | |||
| def test_yes_no_sequential_sampler(): | |||
| """ | |||
| Feature: YesNo Dataset | |||
| Description: Use SequentialSampler to sample data. | |||
| Expectation: The number of samplers returned by dict_iterator is equal to the requested number of samples. | |||
| """ | |||
| logger.info("Test YesNoDataset Op with SequentialSampler") | |||
| num_samples = 2 | |||
| sampler = ds.SequentialSampler(num_samples=num_samples) | |||
| data1 = ds.YesNoDataset(DATA_DIR, sampler=sampler) | |||
| data2 = ds.YesNoDataset(DATA_DIR, shuffle=False, num_samples=num_samples) | |||
| sample_rate_list1, sample_rate_list2 = [], [] | |||
| num_iter = 0 | |||
| for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1), | |||
| data2.create_dict_iterator(num_epochs=1)): | |||
| sample_rate_list1.append(item1["sample_rate"]) | |||
| sample_rate_list2.append(item2["sample_rate"]) | |||
| num_iter += 1 | |||
| np.testing.assert_array_equal(sample_rate_list1, sample_rate_list2) | |||
| assert num_iter == num_samples | |||
| def test_yes_no_exception(): | |||
| """ | |||
| Feature: Error tests | |||
| Description: Throw error messages when certain errors occur | |||
| Expectation: Output error message | |||
| """ | |||
| logger.info("Test error cases for YesNoDataset") | |||
| error_msg_1 = "sampler and shuffle cannot be specified at the same time" | |||
| with pytest.raises(RuntimeError, match=error_msg_1): | |||
| ds.YesNoDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) | |||
| error_msg_2 = "sampler and sharding cannot be specified at the same time" | |||
| with pytest.raises(RuntimeError, match=error_msg_2): | |||
| ds.YesNoDataset(DATA_DIR, sampler=ds.PKSampler(3), | |||
| num_shards=2, shard_id=0) | |||
| error_msg_3 = "num_shards is specified and currently requires shard_id as well" | |||
| with pytest.raises(RuntimeError, match=error_msg_3): | |||
| ds.YesNoDataset(DATA_DIR, num_shards=10) | |||
| error_msg_4 = "shard_id is specified but num_shards is not" | |||
| with pytest.raises(RuntimeError, match=error_msg_4): | |||
| ds.YesNoDataset(DATA_DIR, shard_id=0) | |||
| error_msg_5 = "Input shard_id is not within the required interval" | |||
| with pytest.raises(ValueError, match=error_msg_5): | |||
| ds.YesNoDataset(DATA_DIR, num_shards=5, shard_id=-1) | |||
| with pytest.raises(ValueError, match=error_msg_5): | |||
| ds.YesNoDataset(DATA_DIR, num_shards=5, shard_id=5) | |||
| with pytest.raises(ValueError, match=error_msg_5): | |||
| ds.YesNoDataset(DATA_DIR, num_shards=2, shard_id=5) | |||
| error_msg_6 = "num_parallel_workers exceeds" | |||
| with pytest.raises(ValueError, match=error_msg_6): | |||
| ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) | |||
| with pytest.raises(ValueError, match=error_msg_6): | |||
| ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) | |||
| with pytest.raises(ValueError, match=error_msg_6): | |||
| ds.YesNoDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) | |||
| error_msg_7 = "Argument shard_id" | |||
| with pytest.raises(TypeError, match=error_msg_7): | |||
| ds.YesNoDataset(DATA_DIR, num_shards=2, shard_id="0") | |||
| def exception_func(item): | |||
| raise Exception("Error occur!") | |||
| error_msg_8 = "The corresponding data files" | |||
| with pytest.raises(RuntimeError, match=error_msg_8): | |||
| data = ds.YesNoDataset(DATA_DIR) | |||
| data = data.map(operations=exception_func, input_columns=[ | |||
| "waveform"], num_parallel_workers=1) | |||
| for _ in data.__iter__(): | |||
| pass | |||
| with pytest.raises(RuntimeError, match=error_msg_8): | |||
| data = ds.YesNoDataset(DATA_DIR) | |||
| data = data.map(operations=exception_func, input_columns=[ | |||
| "sample_rate"], num_parallel_workers=1) | |||
| for _ in data.__iter__(): | |||
| pass | |||
| def test_yes_no_pipeline(): | |||
| """ | |||
| Feature: Pipeline test | |||
| Description: Read a sample | |||
| Expectation: The amount of each function are equal | |||
| """ | |||
| # Original waveform | |||
| dataset = ds.YesNoDataset(DATA_DIR, num_samples=1) | |||
| band_biquad_op = audio.BandBiquad(8000, 200.0) | |||
| # Filtered waveform by bandbiquad | |||
| dataset = dataset.map(input_columns=["waveform"], operations=band_biquad_op, num_parallel_workers=2) | |||
| num_iter = 0 | |||
| for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| num_iter += 1 | |||
| assert num_iter == 1 | |||
| if __name__ == '__main__': | |||
| test_yes_no_basic() | |||
| test_yes_no_num_samples() | |||
| test_yes_no_repeat() | |||
| test_yes_no_dataset_size() | |||
| test_yes_no_sequential_sampler() | |||
| test_yes_no_exception() | |||
| test_yes_no_pipeline() | |||