Merge pull request !22421 from TR-nbu/GTZANDatasetfeature/build-system-rewrite
| @@ -92,6 +92,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h" | |||
| @@ -1235,6 +1236,27 @@ FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::ve | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) { | |||
| auto sampler_obj = sampler ? sampler->Parse() : nullptr; | |||
| auto ds = std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler, | |||
| const std::shared_ptr<DatasetCache> &cache) { | |||
| auto sampler_obj = sampler ? sampler->Parse() : nullptr; | |||
| auto ds = std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) { | |||
| auto sampler_obj = sampler.get().Parse(); | |||
| auto ds = std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache); | |||
| ir_node_ = std::static_pointer_cast<DatasetNode>(ds); | |||
| } | |||
| ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode, | |||
| const std::shared_ptr<Sampler> &sampler, | |||
| const std::set<std::vector<char>> &extensions, | |||
| @@ -43,6 +43,7 @@ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/generator_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h" | |||
| #include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h" | |||
| @@ -323,6 +324,17 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) { | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(GTZANNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<GTZANNode, DatasetNode, std::shared_ptr<GTZANNode>>(*m, "GTZANNode", | |||
| "to create a GTZANNode") | |||
| .def( | |||
| py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) { | |||
| auto gtzan = std::make_shared<GTZANNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr); | |||
| THROW_IF_ERROR(gtzan->ValidateParams()); | |||
| return gtzan; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) { | |||
| (void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>( | |||
| *m, "ImageFolderNode", "to create an ImageFolderNode") | |||
| @@ -21,6 +21,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES | |||
| fake_image_op.cc | |||
| fashion_mnist_op.cc | |||
| flickr_op.cc | |||
| gtzan_op.cc | |||
| image_folder_op.cc | |||
| imdb_op.cc | |||
| iwslt_op.cc | |||
| @@ -0,0 +1,336 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/datasetops/source/gtzan_op.h" | |||
| #include <fstream> | |||
| #include <iomanip> | |||
| #include <set> | |||
| #include "minddata/dataset/audio/kernels/audio_utils.h" | |||
| #include "minddata/dataset/core/config_manager.h" | |||
| #include "minddata/dataset/core/tensor_shape.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h" | |||
| #include "minddata/dataset/engine/execution_tree.h" | |||
| #include "utils/file_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| const std::vector<std::string> genres = { | |||
| "blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock", | |||
| }; | |||
| const std::vector<std::string> filtered_test = { | |||
| "blues.00012", "blues.00013", "blues.00014", "blues.00015", "blues.00016", "blues.00017", | |||
| "blues.00018", "blues.00019", "blues.00020", "blues.00021", "blues.00022", "blues.00023", | |||
| "blues.00024", "blues.00025", "blues.00026", "blues.00027", "blues.00028", "blues.00061", | |||
| "blues.00062", "blues.00063", "blues.00064", "blues.00065", "blues.00066", "blues.00067", | |||
| "blues.00068", "blues.00069", "blues.00070", "blues.00071", "blues.00072", "blues.00098", | |||
| "blues.00099", "classical.00011", "classical.00012", "classical.00013", "classical.00014", "classical.00015", | |||
| "classical.00016", "classical.00017", "classical.00018", "classical.00019", "classical.00020", "classical.00021", | |||
| "classical.00022", "classical.00023", "classical.00024", "classical.00025", "classical.00026", "classical.00027", | |||
| "classical.00028", "classical.00029", "classical.00034", "classical.00035", "classical.00036", "classical.00037", | |||
| "classical.00038", "classical.00039", "classical.00040", "classical.00041", "classical.00049", "classical.00077", | |||
| "classical.00078", "classical.00079", "country.00030", "country.00031", "country.00032", "country.00033", | |||
| "country.00034", "country.00035", "country.00036", "country.00037", "country.00038", "country.00039", | |||
| "country.00040", "country.00043", "country.00044", "country.00046", "country.00047", "country.00048", | |||
| "country.00050", "country.00051", "country.00053", "country.00054", "country.00055", "country.00056", | |||
| "country.00057", "country.00058", "country.00059", "country.00060", "country.00061", "country.00062", | |||
| "country.00063", "country.00064", "disco.00001", "disco.00021", "disco.00058", "disco.00062", | |||
| "disco.00063", "disco.00064", "disco.00065", "disco.00066", "disco.00069", "disco.00076", | |||
| "disco.00077", "disco.00078", "disco.00079", "disco.00080", "disco.00081", "disco.00082", | |||
| "disco.00083", "disco.00084", "disco.00085", "disco.00086", "disco.00087", "disco.00088", | |||
| "disco.00091", "disco.00092", "disco.00093", "disco.00094", "disco.00096", "disco.00097", | |||
| "disco.00099", "hiphop.00000", "hiphop.00026", "hiphop.00027", "hiphop.00030", "hiphop.00040", | |||
| "hiphop.00043", "hiphop.00044", "hiphop.00045", "hiphop.00051", "hiphop.00052", "hiphop.00053", | |||
| "hiphop.00054", "hiphop.00062", "hiphop.00063", "hiphop.00064", "hiphop.00065", "hiphop.00066", | |||
| "hiphop.00067", "hiphop.00068", "hiphop.00069", "hiphop.00070", "hiphop.00071", "hiphop.00072", | |||
| "hiphop.00073", "hiphop.00074", "hiphop.00075", "hiphop.00099", "jazz.00073", "jazz.00074", | |||
| "jazz.00075", "jazz.00076", "jazz.00077", "jazz.00078", "jazz.00079", "jazz.00080", | |||
| "jazz.00081", "jazz.00082", "jazz.00083", "jazz.00084", "jazz.00085", "jazz.00086", | |||
| "jazz.00087", "jazz.00088", "jazz.00089", "jazz.00090", "jazz.00091", "jazz.00092", | |||
| "jazz.00093", "jazz.00094", "jazz.00095", "jazz.00096", "jazz.00097", "jazz.00098", | |||
| "jazz.00099", "metal.00012", "metal.00013", "metal.00014", "metal.00015", "metal.00022", | |||
| "metal.00023", "metal.00025", "metal.00026", "metal.00027", "metal.00028", "metal.00029", | |||
| "metal.00030", "metal.00031", "metal.00032", "metal.00033", "metal.00038", "metal.00039", | |||
| "metal.00067", "metal.00070", "metal.00073", "metal.00074", "metal.00075", "metal.00078", | |||
| "metal.00083", "metal.00085", "metal.00087", "metal.00088", "pop.00000", "pop.00001", | |||
| "pop.00013", "pop.00014", "pop.00043", "pop.00063", "pop.00064", "pop.00065", | |||
| "pop.00066", "pop.00069", "pop.00070", "pop.00071", "pop.00072", "pop.00073", | |||
| "pop.00074", "pop.00075", "pop.00076", "pop.00077", "pop.00078", "pop.00079", | |||
| "pop.00082", "pop.00088", "pop.00089", "pop.00090", "pop.00091", "pop.00092", | |||
| "pop.00093", "pop.00094", "pop.00095", "pop.00096", "reggae.00034", "reggae.00035", | |||
| "reggae.00036", "reggae.00037", "reggae.00038", "reggae.00039", "reggae.00040", "reggae.00046", | |||
| "reggae.00047", "reggae.00048", "reggae.00052", "reggae.00053", "reggae.00064", "reggae.00065", | |||
| "reggae.00066", "reggae.00067", "reggae.00068", "reggae.00071", "reggae.00079", "reggae.00082", | |||
| "reggae.00083", "reggae.00084", "reggae.00087", "reggae.00088", "reggae.00089", "reggae.00090", | |||
| "rock.00010", "rock.00011", "rock.00012", "rock.00013", "rock.00014", "rock.00015", | |||
| "rock.00027", "rock.00028", "rock.00029", "rock.00030", "rock.00031", "rock.00032", | |||
| "rock.00033", "rock.00034", "rock.00035", "rock.00036", "rock.00037", "rock.00039", | |||
| "rock.00040", "rock.00041", "rock.00042", "rock.00043", "rock.00044", "rock.00045", | |||
| "rock.00046", "rock.00047", "rock.00048", "rock.00086", "rock.00087", "rock.00088", | |||
| "rock.00089", "rock.00090", | |||
| }; | |||
| const std::vector<std::string> filtered_train = { | |||
| "blues.00029", "blues.00030", "blues.00031", "blues.00032", "blues.00033", "blues.00034", | |||
| "blues.00035", "blues.00036", "blues.00037", "blues.00038", "blues.00039", "blues.00040", | |||
| "blues.00041", "blues.00042", "blues.00043", "blues.00044", "blues.00045", "blues.00046", | |||
| "blues.00047", "blues.00048", "blues.00049", "blues.00073", "blues.00074", "blues.00075", | |||
| "blues.00076", "blues.00077", "blues.00078", "blues.00079", "blues.00080", "blues.00081", | |||
| "blues.00082", "blues.00083", "blues.00084", "blues.00085", "blues.00086", "blues.00087", | |||
| "blues.00088", "blues.00089", "blues.00090", "blues.00091", "blues.00092", "blues.00093", | |||
| "blues.00094", "blues.00095", "blues.00096", "blues.00097", "classical.00030", "classical.00031", | |||
| "classical.00032", "classical.00033", "classical.00043", "classical.00044", "classical.00045", "classical.00046", | |||
| "classical.00047", "classical.00048", "classical.00050", "classical.00051", "classical.00052", "classical.00053", | |||
| "classical.00054", "classical.00055", "classical.00056", "classical.00057", "classical.00058", "classical.00059", | |||
| "classical.00060", "classical.00061", "classical.00062", "classical.00063", "classical.00064", "classical.00065", | |||
| "classical.00066", "classical.00067", "classical.00080", "classical.00081", "classical.00082", "classical.00083", | |||
| "classical.00084", "classical.00085", "classical.00086", "classical.00087", "classical.00088", "classical.00089", | |||
| "classical.00090", "classical.00091", "classical.00092", "classical.00093", "classical.00094", "classical.00095", | |||
| "classical.00096", "classical.00097", "classical.00098", "classical.00099", "country.00019", "country.00020", | |||
| "country.00021", "country.00022", "country.00023", "country.00024", "country.00025", "country.00026", | |||
| "country.00028", "country.00029", "country.00065", "country.00066", "country.00067", "country.00068", | |||
| "country.00069", "country.00070", "country.00071", "country.00072", "country.00073", "country.00074", | |||
| "country.00075", "country.00076", "country.00077", "country.00078", "country.00079", "country.00080", | |||
| "country.00081", "country.00082", "country.00083", "country.00084", "country.00085", "country.00086", | |||
| "country.00087", "country.00088", "country.00089", "country.00090", "country.00091", "country.00092", | |||
| "country.00093", "country.00094", "country.00095", "country.00096", "country.00097", "country.00098", | |||
| "country.00099", "disco.00005", "disco.00015", "disco.00016", "disco.00017", "disco.00018", | |||
| "disco.00019", "disco.00020", "disco.00022", "disco.00023", "disco.00024", "disco.00025", | |||
| "disco.00026", "disco.00027", "disco.00028", "disco.00029", "disco.00030", "disco.00031", | |||
| "disco.00032", "disco.00033", "disco.00034", "disco.00035", "disco.00036", "disco.00037", | |||
| "disco.00039", "disco.00040", "disco.00041", "disco.00042", "disco.00043", "disco.00044", | |||
| "disco.00045", "disco.00047", "disco.00049", "disco.00053", "disco.00054", "disco.00056", | |||
| "disco.00057", "disco.00059", "disco.00061", "disco.00070", "disco.00073", "disco.00074", | |||
| "disco.00089", "hiphop.00002", "hiphop.00003", "hiphop.00004", "hiphop.00005", "hiphop.00006", | |||
| "hiphop.00007", "hiphop.00008", "hiphop.00009", "hiphop.00010", "hiphop.00011", "hiphop.00012", | |||
| "hiphop.00013", "hiphop.00014", "hiphop.00015", "hiphop.00016", "hiphop.00017", "hiphop.00018", | |||
| "hiphop.00019", "hiphop.00020", "hiphop.00021", "hiphop.00022", "hiphop.00023", "hiphop.00024", | |||
| "hiphop.00025", "hiphop.00028", "hiphop.00029", "hiphop.00031", "hiphop.00032", "hiphop.00033", | |||
| "hiphop.00034", "hiphop.00035", "hiphop.00036", "hiphop.00037", "hiphop.00038", "hiphop.00041", | |||
| "hiphop.00042", "hiphop.00055", "hiphop.00056", "hiphop.00057", "hiphop.00058", "hiphop.00059", | |||
| "hiphop.00060", "hiphop.00061", "hiphop.00077", "hiphop.00078", "hiphop.00079", "hiphop.00080", | |||
| "jazz.00000", "jazz.00001", "jazz.00011", "jazz.00012", "jazz.00013", "jazz.00014", | |||
| "jazz.00015", "jazz.00016", "jazz.00017", "jazz.00018", "jazz.00019", "jazz.00020", | |||
| "jazz.00021", "jazz.00022", "jazz.00023", "jazz.00024", "jazz.00041", "jazz.00047", | |||
| "jazz.00048", "jazz.00049", "jazz.00050", "jazz.00051", "jazz.00052", "jazz.00053", | |||
| "jazz.00054", "jazz.00055", "jazz.00056", "jazz.00057", "jazz.00058", "jazz.00059", | |||
| "jazz.00060", "jazz.00061", "jazz.00062", "jazz.00063", "jazz.00064", "jazz.00065", | |||
| "jazz.00066", "jazz.00067", "jazz.00068", "jazz.00069", "jazz.00070", "jazz.00071", | |||
| "jazz.00072", "metal.00002", "metal.00003", "metal.00005", "metal.00021", "metal.00024", | |||
| "metal.00035", "metal.00046", "metal.00047", "metal.00048", "metal.00049", "metal.00050", | |||
| "metal.00051", "metal.00052", "metal.00053", "metal.00054", "metal.00055", "metal.00056", | |||
| "metal.00057", "metal.00059", "metal.00060", "metal.00061", "metal.00062", "metal.00063", | |||
| "metal.00064", "metal.00065", "metal.00066", "metal.00069", "metal.00071", "metal.00072", | |||
| "metal.00079", "metal.00080", "metal.00084", "metal.00086", "metal.00089", "metal.00090", | |||
| "metal.00091", "metal.00092", "metal.00093", "metal.00094", "metal.00095", "metal.00096", | |||
| "metal.00097", "metal.00098", "metal.00099", "pop.00002", "pop.00003", "pop.00004", | |||
| "pop.00005", "pop.00006", "pop.00007", "pop.00008", "pop.00009", "pop.00011", | |||
| "pop.00012", "pop.00016", "pop.00017", "pop.00018", "pop.00019", "pop.00020", | |||
| "pop.00023", "pop.00024", "pop.00025", "pop.00026", "pop.00027", "pop.00028", | |||
| "pop.00029", "pop.00031", "pop.00032", "pop.00033", "pop.00034", "pop.00035", | |||
| "pop.00036", "pop.00038", "pop.00039", "pop.00040", "pop.00041", "pop.00042", | |||
| "pop.00044", "pop.00046", "pop.00049", "pop.00050", "pop.00080", "pop.00097", | |||
| "pop.00098", "pop.00099", "reggae.00000", "reggae.00001", "reggae.00002", "reggae.00004", | |||
| "reggae.00006", "reggae.00009", "reggae.00011", "reggae.00012", "reggae.00014", "reggae.00015", | |||
| "reggae.00016", "reggae.00017", "reggae.00018", "reggae.00019", "reggae.00020", "reggae.00021", | |||
| "reggae.00022", "reggae.00023", "reggae.00024", "reggae.00025", "reggae.00026", "reggae.00027", | |||
| "reggae.00028", "reggae.00029", "reggae.00030", "reggae.00031", "reggae.00032", "reggae.00042", | |||
| "reggae.00043", "reggae.00044", "reggae.00045", "reggae.00049", "reggae.00050", "reggae.00051", | |||
| "reggae.00054", "reggae.00055", "reggae.00056", "reggae.00057", "reggae.00058", "reggae.00059", | |||
| "reggae.00060", "reggae.00063", "reggae.00069", "rock.00000", "rock.00001", "rock.00002", | |||
| "rock.00003", "rock.00004", "rock.00005", "rock.00006", "rock.00007", "rock.00008", | |||
| "rock.00009", "rock.00016", "rock.00017", "rock.00018", "rock.00019", "rock.00020", | |||
| "rock.00021", "rock.00022", "rock.00023", "rock.00024", "rock.00025", "rock.00026", | |||
| "rock.00057", "rock.00058", "rock.00059", "rock.00060", "rock.00061", "rock.00062", | |||
| "rock.00063", "rock.00064", "rock.00065", "rock.00066", "rock.00067", "rock.00068", | |||
| "rock.00069", "rock.00070", "rock.00091", "rock.00092", "rock.00093", "rock.00094", | |||
| "rock.00095", "rock.00096", "rock.00097", "rock.00098", "rock.00099", | |||
| }; | |||
| const std::vector<std::string> filtered_valid = { | |||
| "blues.00000", "blues.00001", "blues.00002", "blues.00003", "blues.00004", "blues.00005", | |||
| "blues.00006", "blues.00007", "blues.00008", "blues.00009", "blues.00010", "blues.00011", | |||
| "blues.00050", "blues.00051", "blues.00052", "blues.00053", "blues.00054", "blues.00055", | |||
| "blues.00056", "blues.00057", "blues.00058", "blues.00059", "blues.00060", "classical.00000", | |||
| "classical.00001", "classical.00002", "classical.00003", "classical.00004", "classical.00005", "classical.00006", | |||
| "classical.00007", "classical.00008", "classical.00009", "classical.00010", "classical.00068", "classical.00069", | |||
| "classical.00070", "classical.00071", "classical.00072", "classical.00073", "classical.00074", "classical.00075", | |||
| "classical.00076", "country.00000", "country.00001", "country.00002", "country.00003", "country.00004", | |||
| "country.00005", "country.00006", "country.00007", "country.00009", "country.00010", "country.00011", | |||
| "country.00012", "country.00013", "country.00014", "country.00015", "country.00016", "country.00017", | |||
| "country.00018", "country.00027", "country.00041", "country.00042", "country.00045", "country.00049", | |||
| "disco.00000", "disco.00002", "disco.00003", "disco.00004", "disco.00006", "disco.00007", | |||
| "disco.00008", "disco.00009", "disco.00010", "disco.00011", "disco.00012", "disco.00013", | |||
| "disco.00014", "disco.00046", "disco.00048", "disco.00052", "disco.00067", "disco.00068", | |||
| "disco.00072", "disco.00075", "disco.00090", "disco.00095", "hiphop.00081", "hiphop.00082", | |||
| "hiphop.00083", "hiphop.00084", "hiphop.00085", "hiphop.00086", "hiphop.00087", "hiphop.00088", | |||
| "hiphop.00089", "hiphop.00090", "hiphop.00091", "hiphop.00092", "hiphop.00093", "hiphop.00094", | |||
| "hiphop.00095", "hiphop.00096", "hiphop.00097", "hiphop.00098", "jazz.00002", "jazz.00003", | |||
| "jazz.00004", "jazz.00005", "jazz.00006", "jazz.00007", "jazz.00008", "jazz.00009", | |||
| "jazz.00010", "jazz.00025", "jazz.00026", "jazz.00027", "jazz.00028", "jazz.00029", | |||
| "jazz.00030", "jazz.00031", "jazz.00032", "metal.00000", "metal.00001", "metal.00006", | |||
| "metal.00007", "metal.00008", "metal.00009", "metal.00010", "metal.00011", "metal.00016", | |||
| "metal.00017", "metal.00018", "metal.00019", "metal.00020", "metal.00036", "metal.00037", | |||
| "metal.00068", "metal.00076", "metal.00077", "metal.00081", "metal.00082", "pop.00010", | |||
| "pop.00053", "pop.00055", "pop.00058", "pop.00059", "pop.00060", "pop.00061", | |||
| "pop.00062", "pop.00081", "pop.00083", "pop.00084", "pop.00085", "pop.00086", | |||
| "reggae.00061", "reggae.00062", "reggae.00070", "reggae.00072", "reggae.00074", "reggae.00076", | |||
| "reggae.00077", "reggae.00078", "reggae.00085", "reggae.00092", "reggae.00093", "reggae.00094", | |||
| "reggae.00095", "reggae.00096", "reggae.00097", "reggae.00098", "reggae.00099", "rock.00038", | |||
| "rock.00049", "rock.00050", "rock.00051", "rock.00052", "rock.00053", "rock.00054", | |||
| "rock.00055", "rock.00056", "rock.00071", "rock.00072", "rock.00073", "rock.00074", | |||
| "rock.00075", "rock.00076", "rock.00077", "rock.00078", "rock.00079", "rock.00080", | |||
| "rock.00081", "rock.00082", "rock.00083", "rock.00084", "rock.00085", | |||
| }; | |||
| GTZANOp::GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler) | |||
| : MappableLeafOp(num_workers, queue_size, std::move(sampler)), | |||
| usage_(usage), | |||
| folder_path_(folder_path), | |||
| data_schema_(std::move(data_schema)) {} | |||
| Status GTZANOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) { | |||
| RETURN_UNEXPECTED_IF_NULL(trow); | |||
| const uint32_t sample_rate = 22050; | |||
| std::shared_ptr<Tensor> waveform, rate, label; | |||
| RETURN_IF_NOT_OK(ReadAudio(audio_names_[row_id].first, &waveform)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateScalar(sample_rate, &rate)); | |||
| RETURN_IF_NOT_OK(Tensor::CreateScalar(audio_names_[row_id].second, &label)); | |||
| (*trow) = TensorRow(row_id, {std::move(waveform), std::move(rate), std::move(label)}); | |||
| trow->setPath({audio_names_[row_id].first, audio_names_[row_id].first, audio_names_[row_id].first}); | |||
| return Status::OK(); | |||
| } | |||
| void GTZANOp::Print(std::ostream &out, bool show_all) const { | |||
| if (!show_all) { | |||
| ParallelOp::Print(out, show_all); | |||
| out << "\n"; | |||
| return; | |||
| } | |||
| ParallelOp::Print(out, show_all); | |||
| out << "\nNumber of rows: " << num_rows_ << "\nGTZAN directory: " << folder_path_ << "\n\n"; | |||
| } | |||
| Status GTZANOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) { | |||
| RETURN_UNEXPECTED_IF_NULL(count); | |||
| *count = 0; | |||
| const int64_t num_samples = 0; | |||
| const int64_t start_index = 0; | |||
| auto sampler = std::make_shared<SequentialSamplerRT>(start_index, num_samples); | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1))); | |||
| TensorShape scalar_rate = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn( | |||
| ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); | |||
| TensorShape scalar_label = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label))); | |||
| std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); | |||
| int32_t num_workers = cfg->num_parallel_workers(); | |||
| int32_t op_connect_size = cfg->op_connector_size(); | |||
| auto op = std::make_shared<GTZANOp>(usage, num_workers, dir, op_connect_size, std::move(schema), std::move(sampler)); | |||
| RETURN_IF_NOT_OK(op->PrepareData()); | |||
| *count = op->audio_names_.size(); | |||
| return Status::OK(); | |||
| } | |||
| Status GTZANOp::ComputeColMap() { | |||
| if (column_name_id_map_.empty()) { | |||
| for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { | |||
| column_name_id_map_[data_schema_->Column(i).Name()] = i; | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Column name map is already set!"; | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| Status GTZANOp::ReadAudio(const std::string &audio_dir, std::shared_ptr<Tensor> *waveform) { | |||
| RETURN_UNEXPECTED_IF_NULL(waveform); | |||
| const int32_t kWavFileSampleRate = 22050; | |||
| int32_t sample_rate = 0; | |||
| std::vector<float> waveform_vec; | |||
| RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &waveform_vec, &sample_rate)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(sample_rate == kWavFileSampleRate, | |||
| "Invalid file, sampling rate of GTZAN wav file must be 22050, file path: " + audio_dir); | |||
| RETURN_IF_NOT_OK(Tensor::CreateFromVector(waveform_vec, waveform)); | |||
| RETURN_IF_NOT_OK((*waveform)->ExpandDim(0)); | |||
| return Status::OK(); | |||
| } | |||
| Status GTZANOp::PrepareData() { | |||
| auto realpath = FileUtils::GetRealPath(folder_path_.data()); | |||
| if (!realpath.has_value()) { | |||
| MS_LOG(ERROR) << "Invalid file path, GTZAN Dataset dir: " << folder_path_ << " does not exist."; | |||
| RETURN_STATUS_UNEXPECTED("Invalid file path, GTZAN Dataset dir: " + folder_path_ + " does not exist."); | |||
| } | |||
| Path dir(folder_path_); | |||
| if (usage_ == "all") { | |||
| for (std::string sub_directory : genres) { | |||
| Path full_dir = dir / sub_directory; | |||
| if (!full_dir.Exists() || !full_dir.IsDirectory()) { | |||
| continue; | |||
| } | |||
| auto dir_it = Path::DirIterator::OpenDirectory(&full_dir); | |||
| if (dir_it != nullptr) { | |||
| while (dir_it->HasNext()) { | |||
| Path file = dir_it->Next(); | |||
| std::string file_name = file.ToString(); | |||
| auto pos = file_name.find_last_of('.'); | |||
| std::string name = file_name.substr(0, pos), temp_ext = file_name.substr(pos); | |||
| if (temp_ext == ".wav" && name.find('.') != std::string::npos) { | |||
| audio_names_.push_back({file.ToString(), sub_directory}); | |||
| } else { | |||
| MS_LOG(WARNING) << "Invalid file, invalid file name or file type: " << file.ToString() << "."; | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Invalid file path, unable to open directory: " << full_dir.ToString() << "."; | |||
| } | |||
| } | |||
| } else { | |||
| const std::vector<std::string> *files_point = nullptr; | |||
| if (usage_ == "test") { | |||
| files_point = &filtered_test; | |||
| } else if (usage_ == "train") { | |||
| files_point = &filtered_train; | |||
| } else { | |||
| files_point = &filtered_valid; | |||
| } | |||
| std::string ext = ".wav"; | |||
| for (auto sub_file_name : *files_point) { | |||
| auto pos = sub_file_name.find_first_of('.'); | |||
| std::string cls = sub_file_name.substr(0, pos); | |||
| Path full_dir = dir / cls / (sub_file_name + ext); | |||
| if (full_dir.Exists()) { | |||
| audio_names_.push_back({full_dir.ToString(), cls}); | |||
| } else { | |||
| MS_LOG(WARNING) << "The audio file is lost, file name= " << (sub_file_name + ext); | |||
| } | |||
| } | |||
| } | |||
| num_rows_ = audio_names_.size(); | |||
| CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Invalid data, no valid data found in path:" + folder_path_); | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/engine/data_schema.h" | |||
| #include "minddata/dataset/engine/datasetops/parallel_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h" | |||
| #include "minddata/dataset/engine/datasetops/source/sampler/sampler.h" | |||
| #include "minddata/dataset/util/path.h" | |||
| #include "minddata/dataset/util/queue.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "minddata/dataset/util/wait_post.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class GTZANOp : public MappableLeafOp { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] usage Usage of this dataset, can be 'train', 'valid', 'test', or 'all'. | |||
| /// \param[in] num_workers Number of workers reading audios in parallel. | |||
| /// \param[in] folder_path Dir directory of GTZAN. | |||
| /// \param[in] queue_size Connector queue size. | |||
| /// \param[in] data_schema The schema of the GTZAN dataset. | |||
| /// \param[in] sampler Sampler tells GTZANOp what to read. | |||
| GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size, | |||
| std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler); | |||
| /// \Destructor. | |||
| ~GTZANOp() = default; | |||
| /// \A print method typically used for debugging. | |||
| /// \param[out] out Output stream. | |||
| /// \param[in] show_all Whether to show all information. | |||
| void Print(std::ostream &out, bool show_all) const override; | |||
| /// \Function to count the number of samples in the GTZAN dataset. | |||
| /// \param[in] dir Path to the GTZAN directory. | |||
| /// \param[in] usage Choose the subset of GTZAN dataset. | |||
| /// \param[out] count Output arg that will hold the actual dataset size. | |||
| /// \return Status The status code returned. | |||
| static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count); | |||
| /// \Op name getter. | |||
| /// \return Name of the current Op. | |||
| std::string Name() const override { return "GTZANOp"; } | |||
| private: | |||
| /// \Load a tensor row according to a pair. | |||
| /// \param[in] row_id Id for this tensor row. | |||
| /// \param[out] row Audio & label read into this tensor row. | |||
| /// \return Status The status code returned. | |||
| Status LoadTensorRow(row_id_type row_id, TensorRow *row) override; | |||
| /// \Parse a audio file. | |||
| /// \param[in] audio_dir Audio file path. | |||
| /// \param[out] waveform The output waveform tensor. | |||
| /// \return Status The status code returned. | |||
| Status ReadAudio(const std::string &audio_dir, std::shared_ptr<Tensor> *waveform); | |||
| /// \Prepare data. | |||
| /// \return Status The status code returned. | |||
| Status PrepareData(); | |||
| /// \Private function for computing the assignment of the column name map. | |||
| /// \return Status The status code returned. | |||
| Status ComputeColMap() override; | |||
| const std::string usage_; | |||
| std::string folder_path_; | |||
| std::unique_ptr<DataSchema> data_schema_; | |||
| std::vector<std::pair<std::string, std::string>> audio_names_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_ | |||
| @@ -96,6 +96,7 @@ constexpr char kFakeImageNode[] = "FakeImageDataset"; | |||
| constexpr char kFashionMnistNode[] = "FashionMnistDataset"; | |||
| constexpr char kFlickrNode[] = "FlickrDataset"; | |||
| constexpr char kGeneratorNode[] = "GeneratorDataset"; | |||
| constexpr char kGTZANNode[] = "GTZANDataset"; | |||
| constexpr char kImageFolderNode[] = "ImageFolderDataset"; | |||
| constexpr char kIMDBNode[] = "IMDBDataset"; | |||
| constexpr char kIWSLT2016Node[] = "IWSLT2016Dataset"; | |||
| @@ -22,6 +22,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES | |||
| fake_image_node.cc | |||
| fashion_mnist_node.cc | |||
| flickr_node.cc | |||
| gtzan_node.cc | |||
| image_folder_node.cc | |||
| imdb_node.cc | |||
| iwslt2016_node.cc | |||
| @@ -0,0 +1,110 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h" | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/datasetops/source/gtzan_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| GTZANNode::GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||
| std::shared_ptr<DatasetCache> cache) | |||
| : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {} | |||
| void GTZANNode::Print(std::ostream &out) const { out << Name(); } | |||
| std::shared_ptr<DatasetNode> GTZANNode::Copy() { | |||
| std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy(); | |||
| auto node = std::make_shared<GTZANNode>(dataset_dir_, usage_, sampler, cache_); | |||
| return node; | |||
| } | |||
| Status GTZANNode::ValidateParams() { | |||
| RETURN_IF_NOT_OK(DatasetNode::ValidateParams()); | |||
| RETURN_IF_NOT_OK(ValidateDatasetDirParam("GTZANDataset", dataset_dir_)); | |||
| RETURN_IF_NOT_OK(ValidateDatasetSampler("GTZANDataset", sampler_)); | |||
| RETURN_IF_NOT_OK(ValidateStringValue("GTZANDataset", usage_, {"train", "valid", "test", "all"})); | |||
| return Status::OK(); | |||
| } | |||
| Status GTZANNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) { | |||
| // Do internal Schema generation. | |||
| auto schema = std::make_unique<DataSchema>(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT64), TensorImpl::kCv, 1))); | |||
| TensorShape scalar_rate = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK(schema->AddColumn( | |||
| ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate))); | |||
| TensorShape scalar_label = TensorShape::CreateScalar(); | |||
| RETURN_IF_NOT_OK( | |||
| schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label))); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| auto op = std::make_shared<GTZANOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema), | |||
| std::move(sampler_rt)); | |||
| op->SetTotalRepeats(GetTotalRepeats()); | |||
| op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch()); | |||
| node_ops->push_back(op); | |||
| return Status::OK(); | |||
| } | |||
| // Get the shard id of node. | |||
| Status GTZANNode::GetShardId(int32_t *shard_id) { | |||
| *shard_id = sampler_->ShardId(); | |||
| return Status::OK(); | |||
| } | |||
| // Get Dataset size. | |||
| Status GTZANNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) { | |||
| if (dataset_size_ > 0) { | |||
| *dataset_size = dataset_size_; | |||
| return Status::OK(); | |||
| } | |||
| int64_t num_rows, sample_size; | |||
| RETURN_IF_NOT_OK(GTZANOp::CountTotalRows(dataset_dir_, usage_, &num_rows)); | |||
| std::shared_ptr<SamplerRT> sampler_rt = nullptr; | |||
| RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt)); | |||
| sample_size = sampler_rt->CalculateNumSamples(num_rows); | |||
| if (sample_size == -1) { | |||
| RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size)); | |||
| } | |||
| *dataset_size = sample_size; | |||
| dataset_size_ = *dataset_size; | |||
| return Status::OK(); | |||
| } | |||
| Status GTZANNode::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args, sampler_args; | |||
| RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); | |||
| args["sampler"] = sampler_args; | |||
| args["num_parallel_workers"] = num_workers_; | |||
| args["dataset_dir"] = dataset_dir_; | |||
| args["usage"] = usage_; | |||
| if (cache_ != nullptr) { | |||
| nlohmann::json cache_args; | |||
| RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); | |||
| args["cache"] = cache_args; | |||
| } | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,95 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class GTZANNode : public MappableSourceNode { | |||
| public: | |||
| /// \brief Constructor | |||
| GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler, | |||
| std::shared_ptr<DatasetCache> cache); | |||
| /// \brief Destructor | |||
| ~GTZANNode() = default; | |||
| /// \brief Node name getter. | |||
| /// \return Name of the current node. | |||
| std::string Name() const override { return "kGTZANNode"; } | |||
| /// \brief Print the description. | |||
| /// \param out The output stream to write output to. | |||
| void Print(std::ostream &out) const override; | |||
| /// \brief Copy the node to a new object. | |||
| /// \return A shared pointer to the new copy. | |||
| std::shared_ptr<DatasetNode> Copy() override; | |||
| /// \brief a base class override function to create the required runtime dataset op objects for this class. | |||
| /// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create. | |||
| /// \return Status Status::OK() if build successfully. | |||
| Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override; | |||
| /// \brief Parameters validation. | |||
| /// \return Status Status::OK() if all the parameters are valid. | |||
| Status ValidateParams() override; | |||
| /// \brief Get the shard id of node. | |||
| /// \param[in] shard_id The shard ID within num_shards. | |||
| /// \return Status Status::OK() if get shard id successfully. | |||
| Status GetShardId(int32_t *shard_id) override; | |||
| /// \brief Base-class override for GetDatasetSize. | |||
| /// \param[in] size_getter Shared pointer to DatasetSizeGetter. | |||
| /// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting | |||
| /// dataset size at the expense of accuracy. | |||
| /// \param[out] dataset_size the size of the dataset. | |||
| /// \return Status of the function. | |||
| Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate, | |||
| int64_t *dataset_size) override; | |||
| /// \brief Getter functions. | |||
| const std::string &DatasetDir() const { return dataset_dir_; } | |||
| const std::string &Usage() const { return usage_; } | |||
| /// \brief Get the arguments of node. | |||
| /// \param[out] out_json JSON string of all attributes. | |||
| /// \return Status of the function. | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| /// \brief Sampler getter. | |||
| /// \return SamplerObj of the current node. | |||
| std::shared_ptr<SamplerObj> Sampler() override { return sampler_; } | |||
| /// \brief Sampler setter. | |||
| /// \param[in] sampler Tells GTZANOp what to read. | |||
| void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; } | |||
| private: | |||
| std::string dataset_dir_; | |||
| std::string usage_; | |||
| std::shared_ptr<SamplerObj> sampler_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_ | |||
| @@ -2647,6 +2647,93 @@ inline std::shared_ptr<FlickrDataset> MS_API Flickr(const std::string &dataset_d | |||
| cache); | |||
| } | |||
| /// \class GTZANDataset | |||
| /// \brief A source dataset for reading and parsing GTZAN dataset. | |||
| class MS_API GTZANDataset : public Dataset { | |||
| public: | |||
| /// \brief Constructor of GTZANDataset. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all" (default = "all"). | |||
| /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use. | |||
| GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache); | |||
| /// \brief Constructor of GTZANDataset. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all". | |||
| /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use. | |||
| GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler, | |||
| const std::shared_ptr<DatasetCache> &cache); | |||
| /// \brief Constructor of GTZANDataset. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all". | |||
| /// \param[in] sampler Sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use. | |||
| GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, | |||
| const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache); | |||
| /// \brief Destructor of GTZANDataset. | |||
| ~GTZANDataset() = default; | |||
| }; | |||
| /// \brief Function to create a GTZANDataset. | |||
| /// \note The generated dataset has three columns ["waveform", "sample_rate", "label"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all" (default = "all"). | |||
| /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not | |||
| /// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()). | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the GTZANDataset. | |||
| /// \code | |||
| /// /* Define dataset path and MindData object */ | |||
| /// std::string folder_path = "/path/to/gtzan_dataset_directory"; | |||
| /// std::shared_ptr<Dataset> ds = | |||
| /// GTZANDataset(folder_path, usage = "all", std::make_shared<RandomSampler>(false, 10)); | |||
| /// | |||
| /// /* Create iterator to read dataset */ | |||
| /// std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| /// std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| /// iter->GetNextRow(&row); | |||
| /// | |||
| /// /* Note: In GTZAN dataset, each data dictionary has keys "waveform", "sample_rate" and "label" */ | |||
| /// auto waveform = row["waveform"]; | |||
| /// \endcode | |||
| inline std::shared_ptr<GTZANDataset> MS_API | |||
| GTZAN(const std::string &dataset_dir, const std::string &usage = "all", | |||
| const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(), | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<GTZANDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); | |||
| } | |||
| /// \brief Function to create a GTZANDataset. | |||
| /// \note The generated dataset has three columns ["waveform", "sample_rate", "label"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all". | |||
| /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the GTZANDataset. | |||
| inline std::shared_ptr<GTZANDataset> MS_API GTZAN(const std::string &dataset_dir, const std::string &usage, | |||
| const Sampler *sampler, | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<GTZANDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); | |||
| } | |||
| /// \brief Function to create a GTZANDataset. | |||
| /// \note The generated dataset has three columns ["waveform", "sample_rate", "label"]. | |||
| /// \param[in] dataset_dir Path to the root directory that contains the dataset. | |||
| /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all". | |||
| /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset. | |||
| /// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used). | |||
| /// \return Shared pointer to the GTZANDataset. | |||
| inline std::shared_ptr<GTZANDataset> MS_API GTZAN(const std::string &dataset_dir, const std::string &usage, | |||
| const std::reference_wrapper<Sampler> sampler, | |||
| const std::shared_ptr<DatasetCache> &cache = nullptr) { | |||
| return std::make_shared<GTZANDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache); | |||
| } | |||
| /// \class ImageFolderDataset | |||
| /// \brief A source dataset that reads images from a tree of directories. | |||
| class MS_API ImageFolderDataset : public Dataset { | |||
| @@ -45,6 +45,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> { | |||
| friend class FakeImageDataset; | |||
| friend class FashionMnistDataset; | |||
| friend class FlickrDataset; | |||
| friend class GTZANDataset; | |||
| friend class ImageFolderDataset; | |||
| friend class IMDBDataset; | |||
| friend class KMnistDataset; | |||
| @@ -81,6 +81,7 @@ __all__ = ["Caltech101Dataset", # Vision | |||
| "WikiTextDataset", # Text | |||
| "YahooAnswersDataset", # Text | |||
| "YelpReviewDataset", # Text | |||
| "GTZANDataset", # Audio | |||
| "LJSpeechDataset", # Audio | |||
| "SpeechCommandsDataset", # Audio | |||
| "TedliumDataset", # Audio | |||
| @@ -26,12 +26,142 @@ After declaring the dataset object, you can further apply dataset operations | |||
| import mindspore._c_dataengine as cde | |||
| from .datasets import AudioBaseDataset, MappableDataset | |||
| from .validators import check_lj_speech_dataset, check_yes_no_dataset, check_speech_commands_dataset, \ | |||
| check_tedlium_dataset | |||
| from .validators import check_gtzan_dataset, check_lj_speech_dataset, check_speech_commands_dataset, check_tedlium_dataset, \ | |||
| check_yes_no_dataset | |||
| from ..core.validator_helpers import replace_none | |||
| class GTZANDataset(MappableDataset, AudioBaseDataset): | |||
| """ | |||
| A source dataset that reads and parses GTZAN dataset. | |||
| The generated dataset has three columns: :py:obj:`["waveform", "sample_rate", "label"]`. | |||
| The tensor of column :py:obj:`waveform` is of the float32 type. | |||
| The tensor of column :py:obj:`sample_rate` is of a scalar of uint32 type. | |||
| The tensor of column :py:obj:`label` is of a scalar of string type. | |||
| Args: | |||
| dataset_dir (str): Path to the root directory that contains the dataset. | |||
| usage (str, optional): Usage of this dataset, can be "train", "valid", "test" or "all" | |||
| (default=None, all samples). | |||
| num_samples (int, optional): The number of audio to be included in the dataset | |||
| (default=None, will read all audio). | |||
| num_parallel_workers (int, optional): Number of workers to read the data | |||
| (default=None, will use value set in the config). | |||
| shuffle (bool, optional): Whether or not to perform shuffle on the dataset | |||
| (default=None, expected order behavior shown in the table). | |||
| sampler (Sampler, optional): Object used to choose samples from the | |||
| dataset (default=None, expected order behavior shown in the table). | |||
| num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). | |||
| When this argument is specified, `num_samples` reflects the max sample number of per shard. | |||
| shard_id (int, optional): The shard ID within `num_shards` (default=None). This | |||
| argument can only be specified when `num_shards` is also specified. | |||
| cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing | |||
| (default=None, which means no cache is used). | |||
| Raises: | |||
| RuntimeError: If source raises an exception during execution. | |||
| RuntimeError: If dataset_dir does not contain data files. | |||
| RuntimeError: If num_parallel_workers exceeds the max thread numbers. | |||
| RuntimeError: If sampler and shuffle are specified at the same time. | |||
| RuntimeError: If sampler and sharding are specified at the same time. | |||
| RuntimeError: If num_shards is specified but shard_id is None. | |||
| RuntimeError: If shard_id is specified but num_shards is None. | |||
| ValueError: If shard_id is invalid (< 0 or >= num_shards). | |||
| Note: | |||
| - GTZAN doesn't support PKSampler. | |||
| - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive. | |||
| The table below shows what input arguments are allowed and their expected behavior. | |||
| .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle' | |||
| :widths: 25 25 50 | |||
| :header-rows: 1 | |||
| * - Parameter `sampler` | |||
| - Parameter `shuffle` | |||
| - Expected Order Behavior | |||
| * - None | |||
| - None | |||
| - random order | |||
| * - None | |||
| - True | |||
| - random order | |||
| * - None | |||
| - False | |||
| - sequential order | |||
| * - Sampler object | |||
| - None | |||
| - order defined by sampler | |||
| * - Sampler object | |||
| - True | |||
| - not allowed | |||
| * - Sampler object | |||
| - False | |||
| - not allowed | |||
| Examples: | |||
| >>> gtzan_dataset_directory = "/path/to/gtzan_dataset_directory" | |||
| >>> | |||
| >>> # 1) Read 500 samples (audio files) in gtzan_dataset_directory | |||
| >>> dataset = ds.GTZANDataset(gtzan_dataset_directory, usage="all", num_samples=500) | |||
| >>> | |||
| >>> # 2) Read all samples (audio files) in gtzan_dataset_directory | |||
| >>> dataset = ds.GTZANDataset(gtzan_dataset_directory) | |||
| About GTZAN dataset: | |||
| The GTZAN dataset appears in at least 100 published works and is the most commonly used | |||
| public dataset for evaluation in machine listening research for music genre recognition. | |||
| It consists of 1000 audio tracks, each of which is 30 seconds long. It contains 10 genres (blues, | |||
| classical, country, disco, hiphop, jazz, metal, pop, reggae and reggae), each of which is | |||
| represented by 100 tracks. The tracks are all 22050Hz Mono 16-bit audio files in .wav format. | |||
| You can construct the following directory structure from GTZAN dataset and read by MindSpore's API. | |||
| .. code-block:: | |||
| . | |||
| └── gtzan_dataset_directory | |||
| ├── blues | |||
| │ ├──blues.00000.wav | |||
| │ ├──blues.00001.wav | |||
| │ ├──blues.00002.wav | |||
| │ ├──... | |||
| ├── disco | |||
| │ ├──disco.00000.wav | |||
| │ ├──disco.00001.wav | |||
| │ ├──disco.00002.wav | |||
| │ └──... | |||
| └──... | |||
| Citation: | |||
| .. code-block:: | |||
| @misc{tzanetakis_essl_cook_2001, | |||
| author = "Tzanetakis, George and Essl, Georg and Cook, Perry", | |||
| title = "Automatic Musical Genre Classification Of Audio Signals", | |||
| url = "http://ismir2001.ismir.net/pdf/tzanetakis.pdf", | |||
| publisher = "The International Society for Music Information Retrieval", | |||
| year = "2001" | |||
| } | |||
| """ | |||
| @check_gtzan_dataset | |||
| def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None, | |||
| sampler=None, num_shards=None, shard_id=None, cache=None): | |||
| super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples, | |||
| shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache) | |||
| self.dataset_dir = dataset_dir | |||
| self.usage = replace_none(usage, "all") | |||
| def parse(self, children=None): | |||
| return cde.GTZANNode(self.dataset_dir, self.usage, self.sampler) | |||
| class LJSpeechDataset(MappableDataset, AudioBaseDataset): | |||
| """ | |||
| A source dataset that reads and parses LJSpeech dataset. | |||
| @@ -34,6 +34,36 @@ from . import samplers | |||
| from . import cache_client | |||
| def check_gtzan_dataset(method): | |||
| """A wrapper that wraps a parameter checker around the original GTZANDataset.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| _, param_dict = parse_user_args(method, *args, **kwargs) | |||
| nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] | |||
| nreq_param_bool = ['shuffle'] | |||
| dataset_dir = param_dict.get('dataset_dir') | |||
| check_dir(dataset_dir) | |||
| usage = param_dict.get('usage') | |||
| if usage is not None: | |||
| check_valid_str(usage, ['train', 'valid', 'test', 'all'], "usage") | |||
| validate_dataset_param_value(nreq_param_int, param_dict, int) | |||
| validate_dataset_param_value(nreq_param_bool, param_dict, bool) | |||
| check_sampler_shuffle_shard_options(param_dict) | |||
| cache = param_dict.get('cache') | |||
| check_cache_option(cache) | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_imagefolderdataset(method): | |||
| """A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDataset).""" | |||
| @@ -32,6 +32,7 @@ SET(DE_UT_SRCS | |||
| c_api_dataset_fake_image_test.cc | |||
| c_api_dataset_fashion_mnist_test.cc | |||
| c_api_dataset_flickr_test.cc | |||
| c_api_dataset_gtzan_test.cc | |||
| c_api_dataset_imdb_test.cc | |||
| c_api_dataset_iterator_test.cc | |||
| c_api_dataset_iwslt_test.cc | |||
| @@ -0,0 +1,271 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/common.h" | |||
| #include "include/dataset/datasets.h" | |||
| #include "include/dataset/transforms.h" | |||
| using namespace mindspore::dataset; | |||
| using mindspore::dataset::Tensor; | |||
| class MindDataTestPipeline : public UT::DatasetOpTesting { | |||
| protected: | |||
| }; | |||
| /// Feature: GTZANDataset | |||
| /// Description: test GTZAN | |||
| /// Expectation: get correct GTZAN dataset | |||
| TEST_F(MindDataTestPipeline, TestGTZANBasic) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANBasic."; | |||
| std::string file_path = datasets_root_path_ + "/testGTZANData"; | |||
| // Create a GTZAN Dataset | |||
| std::shared_ptr<Dataset> ds = GTZAN(file_path); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| std::string_view label_idx; | |||
| uint32_t rate = 0; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto waveform = row["waveform"]; | |||
| auto label = row["label"]; | |||
| auto sample_rate = row["sample_rate"]; | |||
| MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); | |||
| std::shared_ptr<Tensor> trate; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate)); | |||
| ASSERT_OK(trate->GetItemAt<uint32_t>(&rate, {})); | |||
| EXPECT_EQ(rate, 22050); | |||
| MS_LOG(INFO) << "Tensor label rate: " << rate; | |||
| std::shared_ptr<Tensor> de_label; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(label, &de_label)); | |||
| ASSERT_OK(de_label->GetItemAt(&label_idx, {})); | |||
| std::string s_label(label_idx); | |||
| std::string expected("blues"); | |||
| EXPECT_STREQ(s_label.c_str(), expected.c_str()); | |||
| MS_LOG(INFO) << "Tensor label value: " << label_idx; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| EXPECT_EQ(i, 3); | |||
| // Manually terminate the pipeline | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: GTZANDataset | |||
| /// Description: test GTZAN with Pipeline | |||
| /// Expectation: get correct GTZAN dataset | |||
| TEST_F(MindDataTestPipeline, TestGTZANBasicWithPipeline) { | |||
| MS_LOG(INFO) << "Doing DataSetOpBatchTest-TestGTZANBasicWithPipeline."; | |||
| // Create a GTZANDataset Dataset. | |||
| std::string folder_path = datasets_root_path_ + "/testGTZANData"; | |||
| std::shared_ptr<Dataset> ds = GTZAN(folder_path, "all", std::make_shared<RandomSampler>(false, 2)); | |||
| EXPECT_NE(ds, nullptr); | |||
| auto op = transforms::PadEnd({1, 50000}); | |||
| std::vector<std::string> input_columns = {"waveform"}; | |||
| std::vector<std::string> output_columns = {"waveform"}; | |||
| std::vector<std::string> project_columns = {"label", "waveform", "sample_rate"}; | |||
| ds = ds->Map({op}, input_columns, output_columns, project_columns); | |||
| EXPECT_NE(ds, nullptr); | |||
| ds = ds->Repeat(10); | |||
| EXPECT_NE(ds, nullptr); | |||
| ds = ds->Batch(5); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| iter->GetNextRow(&row); | |||
| std::vector<uint32_t> expected_rate = {22050, 22050, 22050, 22050, 22050}; | |||
| std::vector<std::string> expected_label = {"blues", "blues", "blues", "blues", "blues"}; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| i++; | |||
| auto waveform = row["waveform"]; | |||
| auto label = row["label"]; | |||
| auto sample_rate = row["sample_rate"]; | |||
| std::shared_ptr<Tensor> de_expected_rate; | |||
| ASSERT_OK(Tensor::CreateFromVector(expected_rate, &de_expected_rate)); | |||
| mindspore::MSTensor fix_expected_rate = | |||
| mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_rate)); | |||
| EXPECT_MSTENSOR_EQ(sample_rate, fix_expected_rate); | |||
| std::shared_ptr<Tensor> de_expected_label; | |||
| ASSERT_OK(Tensor::CreateFromVector(expected_label, &de_expected_label)); | |||
| mindspore::MSTensor fix_expected_label = | |||
| mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_label)); | |||
| EXPECT_MSTENSOR_EQ(label, fix_expected_label); | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| EXPECT_EQ(i, 4); | |||
| // Manually terminate the pipeline. | |||
| iter->Stop(); | |||
| } | |||
| /// Feature: GTZANDataset | |||
| /// Description: test GTZAN with invalid directory | |||
| /// Expectation: get correct GTZAN dataset | |||
| TEST_F(MindDataTestPipeline, TestGTZANError) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANError."; | |||
| // Create a GTZAN Dataset with non-existing dataset dir. | |||
| std::shared_ptr<Dataset> ds0 = GTZAN("NotExistFile"); | |||
| EXPECT_NE(ds0, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter0 = ds0->CreateIterator(); | |||
| // Expect failure: invalid GTZAN30k input. | |||
| EXPECT_EQ(iter0, nullptr); | |||
| // Create a GTZAN Dataset with invalid string of dataset dir. | |||
| std::shared_ptr<Dataset> ds1 = GTZAN(":*?\"<>|`&;'"); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter1 = ds1->CreateIterator(); | |||
| // Expect failure: invalid GTZAN input. | |||
| EXPECT_EQ(iter1, nullptr); | |||
| } | |||
| /// Feature: GTZANDataset | |||
| /// Description: test GTZAN with Getters | |||
| /// Expectation: dataset is null | |||
| TEST_F(MindDataTestPipeline, TestGTZANGetters) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANGetters."; | |||
| std::string folder_path = datasets_root_path_ + "/testGTZANData"; | |||
| // Create a GTZAN Dataset. | |||
| std::shared_ptr<Dataset> ds1 = GTZAN(folder_path); | |||
| std::shared_ptr<Dataset> ds2 = GTZAN(folder_path, "all"); | |||
| std::shared_ptr<Dataset> ds3 = GTZAN(folder_path, "valid"); | |||
| std::vector<std::string> column_names = {"waveform", "sample_rate", "label"}; | |||
| EXPECT_NE(ds1, nullptr); | |||
| EXPECT_EQ(ds1->GetDatasetSize(), 3); | |||
| EXPECT_EQ(ds1->GetColumnNames(), column_names); | |||
| EXPECT_NE(ds2, nullptr); | |||
| EXPECT_EQ(ds2->GetDatasetSize(), 3); | |||
| EXPECT_EQ(ds2->GetColumnNames(), column_names); | |||
| EXPECT_NE(ds3, nullptr); | |||
| EXPECT_EQ(ds3->GetDatasetSize(), 3); | |||
| EXPECT_EQ(ds3->GetColumnNames(), column_names); | |||
| } | |||
| /// Feature: GTZANDataset | |||
| /// Description: test GTZAN dataset with invalid usage | |||
| /// Expectation: dataset is null | |||
| TEST_F(MindDataTestPipeline, TestGTZANWithInvalidUsageError) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithInvalidUsageError."; | |||
| std::string folder_path = datasets_root_path_ + "/testGTZANData"; | |||
| // Create a GTZAN Dataset. | |||
| std::shared_ptr<Dataset> ds1 = GTZAN(folder_path, "----"); | |||
| EXPECT_NE(ds1, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter1 = ds1->CreateIterator(); | |||
| EXPECT_EQ(iter1, nullptr); | |||
| std::shared_ptr<Dataset> ds2 = GTZAN(folder_path, "csacs"); | |||
| EXPECT_NE(ds2, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter2 = ds2->CreateIterator(); | |||
| EXPECT_EQ(iter2, nullptr); | |||
| } | |||
| /// Feature: GTZANDataset | |||
| /// Description: test GTZAN dataset with null sampler | |||
| /// Expectation: dataset is null | |||
| TEST_F(MindDataTestPipeline, TestGTZANWithNullSamplerError) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithNullSamplerError."; | |||
| std::string folder_path = datasets_root_path_ + "/testGTZANData"; | |||
| // Create a GTZAN Dataset. | |||
| std::shared_ptr<Dataset> ds = GTZAN(folder_path, "all ", nullptr); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| // Expect failure: invalid GTZAN input, sampler cannot be nullptr. | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| /// Feature: GTZANDataset | |||
| /// Description: test GTZAN with sequential sampler | |||
| /// Expectation: get correct GTZAN dataset | |||
| TEST_F(MindDataTestPipeline, TestGTZANNumSamplers) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithSequentialSampler."; | |||
| std::string folder_path = datasets_root_path_ + "/testGTZANData"; | |||
| // Create a GTZAN Dataset. | |||
| std::shared_ptr<Dataset> ds = GTZAN(folder_path, "all", std::make_shared<SequentialSampler>(0, 2)); | |||
| EXPECT_NE(ds, nullptr); | |||
| // Create an iterator over the result of the above dataset. | |||
| // This will trigger the creation of the Execution Tree and launch it. | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(iter, nullptr); | |||
| // Iterate the dataset and get each row. | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| uint32_t rate = 0; | |||
| uint64_t i = 0; | |||
| while (row.size() != 0) { | |||
| auto waveform = row["waveform"]; | |||
| auto sample_rate = row["sample_rate"]; | |||
| MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape(); | |||
| std::shared_ptr<Tensor> t_rate; | |||
| ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &t_rate)); | |||
| ASSERT_OK(t_rate->GetItemAt<uint32_t>(&rate, {})); | |||
| EXPECT_EQ(rate, 22050); | |||
| MS_LOG(INFO) << "Tensor sample rate: " << rate; | |||
| i++; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| } | |||
| EXPECT_EQ(i, 2); | |||
| iter->Stop(); | |||
| } | |||
| @@ -0,0 +1,223 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License foNtest_resr the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """ | |||
| Test Gtzan dataset operators. | |||
| """ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| from mindspore import log as logger | |||
| DATA_DIR = "../data/dataset/testGTZANData" | |||
| def test_gtzan_basic(): | |||
| """ | |||
| Feature: GTZANDataset | |||
| Description: test basic usage of GTZAN | |||
| Expectation: the dataset is as expected | |||
| """ | |||
| logger.info("Test GTZANDataset Op") | |||
| # case 1: test loading whole dataset. | |||
| data1 = ds.GTZANDataset(DATA_DIR) | |||
| num_iter1 = 0 | |||
| for _ in data1.create_dict_iterator(output_numpy=True, num_epochs=1): | |||
| num_iter1 += 1 | |||
| assert num_iter1 == 3 | |||
| # case 2: test num_samples. | |||
| data2 = ds.GTZANDataset(DATA_DIR, num_samples=2) | |||
| num_iter2 = 0 | |||
| for _ in data2.create_dict_iterator(output_numpy=True, num_epochs=1): | |||
| num_iter2 += 1 | |||
| assert num_iter2 == 2 | |||
| # case 3: test repeat. | |||
| data3 = ds.GTZANDataset(DATA_DIR, num_samples=2) | |||
| data3 = data3.repeat(5) | |||
| num_iter3 = 0 | |||
| for _ in data3.create_dict_iterator(output_numpy=True, num_epochs=1): | |||
| num_iter3 += 1 | |||
| assert num_iter3 == 10 | |||
| # case 4: test batch with drop_remainder=False. | |||
| data4 = ds.GTZANDataset(DATA_DIR, num_samples=3) | |||
| assert data4.get_dataset_size() == 3 | |||
| assert data4.get_batch_size() == 1 | |||
| data4 = data4.batch(batch_size=2) # drop_remainder is default to be False. | |||
| assert data4.get_dataset_size() == 2 | |||
| assert data4.get_batch_size() == 2 | |||
| # case 5: test batch with drop_remainder=True. | |||
| data5 = ds.GTZANDataset(DATA_DIR, num_samples=3) | |||
| assert data5.get_dataset_size() == 3 | |||
| assert data5.get_batch_size() == 1 | |||
| # the rest of incomplete batch will be dropped. | |||
| data5 = data5.batch(batch_size=2, drop_remainder=True) | |||
| assert data5.get_dataset_size() == 1 | |||
| assert data5.get_batch_size() == 2 | |||
| def test_gtzan_distribute_sampler(): | |||
| """ | |||
| Feature: GTZANDataset | |||
| Description: test GTZAN dataset with DistributedSampler | |||
| Expectation: the results are as expected | |||
| """ | |||
| logger.info("Test GTZAN with DistributedSampler") | |||
| label_list1, label_list2 = [], [] | |||
| num_shards = 3 | |||
| shard_id = 0 | |||
| data1 = ds.GTZANDataset(DATA_DIR, usage="all", num_shards=num_shards, shard_id=shard_id) | |||
| count = 0 | |||
| for item1 in data1.create_dict_iterator(output_numpy=True, num_epochs=1): | |||
| label_list1.append(item1["label"]) | |||
| count = count + 1 | |||
| assert count == 1 | |||
| num_shards = 3 | |||
| shard_id = 0 | |||
| sampler = ds.DistributedSampler(num_shards, shard_id) | |||
| data2 = ds.GTZANDataset(DATA_DIR, usage="all", sampler=sampler) | |||
| count = 0 | |||
| for item2 in data2.create_dict_iterator(output_numpy=True, num_epochs=1): | |||
| label_list2.append(item2["label"]) | |||
| count = count + 1 | |||
| np.testing.assert_array_equal(label_list1, label_list2) | |||
| assert count == 1 | |||
| def test_gtzan_exception(): | |||
| """ | |||
| Feature: GTZANDataset | |||
| Description: test error cases for GTZANDataset | |||
| Expectation: the results are as expected | |||
| """ | |||
| logger.info("Test error cases for GTZANDataset") | |||
| error_msg_1 = "sampler and shuffle cannot be specified at the same time" | |||
| with pytest.raises(RuntimeError, match=error_msg_1): | |||
| ds.GTZANDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3)) | |||
| error_msg_2 = "sampler and sharding cannot be specified at the same time" | |||
| with pytest.raises(RuntimeError, match=error_msg_2): | |||
| ds.GTZANDataset(DATA_DIR, sampler=ds.PKSampler(3), | |||
| num_shards=2, shard_id=0) | |||
| error_msg_3 = "num_shards is specified and currently requires shard_id as well" | |||
| with pytest.raises(RuntimeError, match=error_msg_3): | |||
| ds.GTZANDataset(DATA_DIR, num_shards=10) | |||
| error_msg_4 = "shard_id is specified but num_shards is not" | |||
| with pytest.raises(RuntimeError, match=error_msg_4): | |||
| ds.GTZANDataset(DATA_DIR, shard_id=0) | |||
| error_msg_5 = "Input shard_id is not within the required interval" | |||
| with pytest.raises(ValueError, match=error_msg_5): | |||
| ds.GTZANDataset(DATA_DIR, num_shards=5, shard_id=-1) | |||
| with pytest.raises(ValueError, match=error_msg_5): | |||
| ds.GTZANDataset(DATA_DIR, num_shards=5, shard_id=5) | |||
| with pytest.raises(ValueError, match=error_msg_5): | |||
| ds.GTZANDataset(DATA_DIR, num_shards=2, shard_id=5) | |||
| error_msg_6 = "num_parallel_workers exceeds" | |||
| with pytest.raises(ValueError, match=error_msg_6): | |||
| ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=0) | |||
| with pytest.raises(ValueError, match=error_msg_6): | |||
| ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=256) | |||
| with pytest.raises(ValueError, match=error_msg_6): | |||
| ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=-2) | |||
| error_msg_7 = "Argument shard_id" | |||
| with pytest.raises(TypeError, match=error_msg_7): | |||
| ds.GTZANDataset(DATA_DIR, num_shards=2, shard_id="0") | |||
| def exception_func(item): | |||
| raise Exception("Error occur!") | |||
| error_msg_8 = "The corresponding data files" | |||
| with pytest.raises(RuntimeError, match=error_msg_8): | |||
| data = ds.GTZANDataset(DATA_DIR) | |||
| data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1) | |||
| for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1): | |||
| pass | |||
| def test_gtzan_sequential_sampler(): | |||
| """ | |||
| Feature: GTZANDataset | |||
| Description: test GTZANDataset with SequentialSampler | |||
| Expectation: the results are as expected | |||
| """ | |||
| logger.info("Test GTZANDataset Op with SequentialSampler") | |||
| num_samples = 2 | |||
| sampler = ds.SequentialSampler(num_samples=num_samples) | |||
| data1 = ds.GTZANDataset(DATA_DIR, sampler=sampler) | |||
| data2 = ds.GTZANDataset(DATA_DIR, shuffle=False, num_samples=num_samples) | |||
| label_list1, label_list2 = [], [] | |||
| num_iter = 0 | |||
| for item1, item2 in zip(data1.create_dict_iterator(output_numpy=True, num_epochs=1), | |||
| data2.create_dict_iterator(output_numpy=True, num_epochs=1)): | |||
| label_list1.append(item1["label"]) | |||
| label_list2.append(item2["label"]) | |||
| num_iter += 1 | |||
| np.testing.assert_array_equal(label_list1, label_list2) | |||
| assert num_iter == num_samples | |||
| def test_gtzan_usage(): | |||
| """ | |||
| Feature: GTZANDataset | |||
| Description: test GTZANDataset usage | |||
| Expectation: the results are as expected | |||
| """ | |||
| logger.info("Test GTZANDataset usage") | |||
| def test_config(usage, gtzan_path=None): | |||
| gtzan_path = DATA_DIR if gtzan_path is None else gtzan_path | |||
| try: | |||
| data = ds.GTZANDataset(gtzan_path, usage=usage, shuffle=False) | |||
| num_rows = 0 | |||
| for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True): | |||
| num_rows += 1 | |||
| except (ValueError, TypeError, RuntimeError) as e: | |||
| return str(e) | |||
| return num_rows | |||
| assert test_config("valid") == 3 | |||
| assert test_config("all") == 3 | |||
| assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid") | |||
| assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"]) | |||
| # change this directory to the folder that contains all gtzan files. | |||
| all_files_path = None | |||
| # the following tests on the entire datasets. | |||
| if all_files_path is not None: | |||
| assert test_config("train", all_files_path) == 3 | |||
| assert test_config("valid", all_files_path) == 3 | |||
| assert ds.GTZANDataset(all_files_path, usage="train").get_dataset_size() == 3 | |||
| assert ds.GTZANDataset(all_files_path, usage="valid").get_dataset_size() == 3 | |||
| if __name__ == '__main__': | |||
| test_gtzan_basic() | |||
| test_gtzan_distribute_sampler() | |||
| test_gtzan_exception() | |||
| test_gtzan_sequential_sampler() | |||
| test_gtzan_usage() | |||