Browse Source

!22421 [assistant][ops] Add new dataset operator GTZANDataset

Merge pull request !22421 from TR-nbu/GTZANDataset
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
10c8b83068
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
20 changed files with 1421 additions and 2 deletions
  1. +22
    -0
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +12
    -0
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc
  3. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt
  4. +336
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.cc
  5. +97
    -0
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.h
  6. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h
  7. +1
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt
  8. +110
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.cc
  9. +95
    -0
      mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.h
  10. +87
    -0
      mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h
  11. +1
    -0
      mindspore/ccsrc/minddata/dataset/include/dataset/samplers.h
  12. +1
    -0
      mindspore/python/mindspore/dataset/engine/__init__.py
  13. +132
    -2
      mindspore/python/mindspore/dataset/engine/datasets_audio.py
  14. +30
    -0
      mindspore/python/mindspore/dataset/engine/validators.py
  15. +1
    -0
      tests/ut/cpp/dataset/CMakeLists.txt
  16. +271
    -0
      tests/ut/cpp/dataset/c_api_dataset_gtzan_test.cc
  17. BIN
      tests/ut/data/dataset/testGTZANData/blues/blues.00000.wav
  18. BIN
      tests/ut/data/dataset/testGTZANData/blues/blues.00001.wav
  19. BIN
      tests/ut/data/dataset/testGTZANData/blues/blues.00002.wav
  20. +223
    -0
      tests/ut/python/dataset/test_datasets_gtzan.py

+ 22
- 0
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -92,6 +92,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
@@ -1235,6 +1236,27 @@ FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::ve
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}

// Creates a GTZAN dataset handle from a shared-pointer sampler.
// A null sampler is forwarded as a null SamplerObj; the IR node decides defaults.
GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                           const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = (sampler == nullptr) ? nullptr : sampler->Parse();
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache));
}

// Creates a GTZAN dataset handle from a raw (non-owning, nullable) sampler pointer.
GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
                           const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = (sampler == nullptr) ? nullptr : sampler->Parse();
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache));
}

// Creates a GTZAN dataset handle from a sampler reference.
// A reference_wrapper can never be null, so Parse() is called unconditionally.
GTZANDataset::GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
                           const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto sampler_obj = sampler.get().Parse();
  ir_node_ = std::static_pointer_cast<DatasetNode>(
    std::make_shared<GTZANNode>(CharToString(dataset_dir), CharToString(usage), sampler_obj, cache));
}

ImageFolderDataset::ImageFolderDataset(const std::vector<char> &dataset_dir, bool decode,
const std::shared_ptr<Sampler> &sampler,
const std::set<std::vector<char>> &extensions,


+ 12
- 0
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/ir/datasetops/source/bindings.cc View File

@@ -43,6 +43,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/fashion_mnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
@@ -323,6 +324,17 @@ PYBIND_REGISTER(GeneratorNode, 2, ([](const py::module *m) {
}));
}));

// Python binding for GTZANNode. The Python side passes (dataset_dir, usage,
// sampler); the cache argument is always nullptr here — presumably caching is
// attached separately on the Python side (TODO confirm against the Python
// GTZANDataset wrapper). ValidateParams() is invoked eagerly so invalid
// arguments surface as an exception at construction time.
PYBIND_REGISTER(GTZANNode, 2, ([](const py::module *m) {
(void)py::class_<GTZANNode, DatasetNode, std::shared_ptr<GTZANNode>>(*m, "GTZANNode",
"to create a GTZANNode")
.def(
py::init([](const std::string &dataset_dir, const std::string &usage, const py::handle &sampler) {
auto gtzan = std::make_shared<GTZANNode>(dataset_dir, usage, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(gtzan->ValidateParams());
return gtzan;
}));
}));

PYBIND_REGISTER(ImageFolderNode, 2, ([](const py::module *m) {
(void)py::class_<ImageFolderNode, DatasetNode, std::shared_ptr<ImageFolderNode>>(
*m, "ImageFolderNode", "to create an ImageFolderNode")


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/CMakeLists.txt View File

@@ -21,6 +21,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
fake_image_op.cc
fashion_mnist_op.cc
flickr_op.cc
gtzan_op.cc
image_folder_op.cc
imdb_op.cc
iwslt_op.cc


+ 336
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.cc View File

@@ -0,0 +1,336 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/gtzan_op.h"
#include <fstream>
#include <iomanip>
#include <set>
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/file_utils.h"
namespace mindspore {
namespace dataset {
// The ten GTZAN genre labels; each is also the name of the sub-directory that
// holds that genre's wav files (see PrepareData's "all" branch).
const std::vector<std::string> genres = {
"blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock",
};
// File stems (without the ".wav" extension) that make up the "test" split of
// the filtered GTZAN partitioning; consumed by PrepareData when usage_ == "test".
const std::vector<std::string> filtered_test = {
"blues.00012", "blues.00013", "blues.00014", "blues.00015", "blues.00016", "blues.00017",
"blues.00018", "blues.00019", "blues.00020", "blues.00021", "blues.00022", "blues.00023",
"blues.00024", "blues.00025", "blues.00026", "blues.00027", "blues.00028", "blues.00061",
"blues.00062", "blues.00063", "blues.00064", "blues.00065", "blues.00066", "blues.00067",
"blues.00068", "blues.00069", "blues.00070", "blues.00071", "blues.00072", "blues.00098",
"blues.00099", "classical.00011", "classical.00012", "classical.00013", "classical.00014", "classical.00015",
"classical.00016", "classical.00017", "classical.00018", "classical.00019", "classical.00020", "classical.00021",
"classical.00022", "classical.00023", "classical.00024", "classical.00025", "classical.00026", "classical.00027",
"classical.00028", "classical.00029", "classical.00034", "classical.00035", "classical.00036", "classical.00037",
"classical.00038", "classical.00039", "classical.00040", "classical.00041", "classical.00049", "classical.00077",
"classical.00078", "classical.00079", "country.00030", "country.00031", "country.00032", "country.00033",
"country.00034", "country.00035", "country.00036", "country.00037", "country.00038", "country.00039",
"country.00040", "country.00043", "country.00044", "country.00046", "country.00047", "country.00048",
"country.00050", "country.00051", "country.00053", "country.00054", "country.00055", "country.00056",
"country.00057", "country.00058", "country.00059", "country.00060", "country.00061", "country.00062",
"country.00063", "country.00064", "disco.00001", "disco.00021", "disco.00058", "disco.00062",
"disco.00063", "disco.00064", "disco.00065", "disco.00066", "disco.00069", "disco.00076",
"disco.00077", "disco.00078", "disco.00079", "disco.00080", "disco.00081", "disco.00082",
"disco.00083", "disco.00084", "disco.00085", "disco.00086", "disco.00087", "disco.00088",
"disco.00091", "disco.00092", "disco.00093", "disco.00094", "disco.00096", "disco.00097",
"disco.00099", "hiphop.00000", "hiphop.00026", "hiphop.00027", "hiphop.00030", "hiphop.00040",
"hiphop.00043", "hiphop.00044", "hiphop.00045", "hiphop.00051", "hiphop.00052", "hiphop.00053",
"hiphop.00054", "hiphop.00062", "hiphop.00063", "hiphop.00064", "hiphop.00065", "hiphop.00066",
"hiphop.00067", "hiphop.00068", "hiphop.00069", "hiphop.00070", "hiphop.00071", "hiphop.00072",
"hiphop.00073", "hiphop.00074", "hiphop.00075", "hiphop.00099", "jazz.00073", "jazz.00074",
"jazz.00075", "jazz.00076", "jazz.00077", "jazz.00078", "jazz.00079", "jazz.00080",
"jazz.00081", "jazz.00082", "jazz.00083", "jazz.00084", "jazz.00085", "jazz.00086",
"jazz.00087", "jazz.00088", "jazz.00089", "jazz.00090", "jazz.00091", "jazz.00092",
"jazz.00093", "jazz.00094", "jazz.00095", "jazz.00096", "jazz.00097", "jazz.00098",
"jazz.00099", "metal.00012", "metal.00013", "metal.00014", "metal.00015", "metal.00022",
"metal.00023", "metal.00025", "metal.00026", "metal.00027", "metal.00028", "metal.00029",
"metal.00030", "metal.00031", "metal.00032", "metal.00033", "metal.00038", "metal.00039",
"metal.00067", "metal.00070", "metal.00073", "metal.00074", "metal.00075", "metal.00078",
"metal.00083", "metal.00085", "metal.00087", "metal.00088", "pop.00000", "pop.00001",
"pop.00013", "pop.00014", "pop.00043", "pop.00063", "pop.00064", "pop.00065",
"pop.00066", "pop.00069", "pop.00070", "pop.00071", "pop.00072", "pop.00073",
"pop.00074", "pop.00075", "pop.00076", "pop.00077", "pop.00078", "pop.00079",
"pop.00082", "pop.00088", "pop.00089", "pop.00090", "pop.00091", "pop.00092",
"pop.00093", "pop.00094", "pop.00095", "pop.00096", "reggae.00034", "reggae.00035",
"reggae.00036", "reggae.00037", "reggae.00038", "reggae.00039", "reggae.00040", "reggae.00046",
"reggae.00047", "reggae.00048", "reggae.00052", "reggae.00053", "reggae.00064", "reggae.00065",
"reggae.00066", "reggae.00067", "reggae.00068", "reggae.00071", "reggae.00079", "reggae.00082",
"reggae.00083", "reggae.00084", "reggae.00087", "reggae.00088", "reggae.00089", "reggae.00090",
"rock.00010", "rock.00011", "rock.00012", "rock.00013", "rock.00014", "rock.00015",
"rock.00027", "rock.00028", "rock.00029", "rock.00030", "rock.00031", "rock.00032",
"rock.00033", "rock.00034", "rock.00035", "rock.00036", "rock.00037", "rock.00039",
"rock.00040", "rock.00041", "rock.00042", "rock.00043", "rock.00044", "rock.00045",
"rock.00046", "rock.00047", "rock.00048", "rock.00086", "rock.00087", "rock.00088",
"rock.00089", "rock.00090",
};
// File stems (without the ".wav" extension) that make up the "train" split of
// the filtered GTZAN partitioning; consumed by PrepareData when usage_ == "train".
const std::vector<std::string> filtered_train = {
"blues.00029", "blues.00030", "blues.00031", "blues.00032", "blues.00033", "blues.00034",
"blues.00035", "blues.00036", "blues.00037", "blues.00038", "blues.00039", "blues.00040",
"blues.00041", "blues.00042", "blues.00043", "blues.00044", "blues.00045", "blues.00046",
"blues.00047", "blues.00048", "blues.00049", "blues.00073", "blues.00074", "blues.00075",
"blues.00076", "blues.00077", "blues.00078", "blues.00079", "blues.00080", "blues.00081",
"blues.00082", "blues.00083", "blues.00084", "blues.00085", "blues.00086", "blues.00087",
"blues.00088", "blues.00089", "blues.00090", "blues.00091", "blues.00092", "blues.00093",
"blues.00094", "blues.00095", "blues.00096", "blues.00097", "classical.00030", "classical.00031",
"classical.00032", "classical.00033", "classical.00043", "classical.00044", "classical.00045", "classical.00046",
"classical.00047", "classical.00048", "classical.00050", "classical.00051", "classical.00052", "classical.00053",
"classical.00054", "classical.00055", "classical.00056", "classical.00057", "classical.00058", "classical.00059",
"classical.00060", "classical.00061", "classical.00062", "classical.00063", "classical.00064", "classical.00065",
"classical.00066", "classical.00067", "classical.00080", "classical.00081", "classical.00082", "classical.00083",
"classical.00084", "classical.00085", "classical.00086", "classical.00087", "classical.00088", "classical.00089",
"classical.00090", "classical.00091", "classical.00092", "classical.00093", "classical.00094", "classical.00095",
"classical.00096", "classical.00097", "classical.00098", "classical.00099", "country.00019", "country.00020",
"country.00021", "country.00022", "country.00023", "country.00024", "country.00025", "country.00026",
"country.00028", "country.00029", "country.00065", "country.00066", "country.00067", "country.00068",
"country.00069", "country.00070", "country.00071", "country.00072", "country.00073", "country.00074",
"country.00075", "country.00076", "country.00077", "country.00078", "country.00079", "country.00080",
"country.00081", "country.00082", "country.00083", "country.00084", "country.00085", "country.00086",
"country.00087", "country.00088", "country.00089", "country.00090", "country.00091", "country.00092",
"country.00093", "country.00094", "country.00095", "country.00096", "country.00097", "country.00098",
"country.00099", "disco.00005", "disco.00015", "disco.00016", "disco.00017", "disco.00018",
"disco.00019", "disco.00020", "disco.00022", "disco.00023", "disco.00024", "disco.00025",
"disco.00026", "disco.00027", "disco.00028", "disco.00029", "disco.00030", "disco.00031",
"disco.00032", "disco.00033", "disco.00034", "disco.00035", "disco.00036", "disco.00037",
"disco.00039", "disco.00040", "disco.00041", "disco.00042", "disco.00043", "disco.00044",
"disco.00045", "disco.00047", "disco.00049", "disco.00053", "disco.00054", "disco.00056",
"disco.00057", "disco.00059", "disco.00061", "disco.00070", "disco.00073", "disco.00074",
"disco.00089", "hiphop.00002", "hiphop.00003", "hiphop.00004", "hiphop.00005", "hiphop.00006",
"hiphop.00007", "hiphop.00008", "hiphop.00009", "hiphop.00010", "hiphop.00011", "hiphop.00012",
"hiphop.00013", "hiphop.00014", "hiphop.00015", "hiphop.00016", "hiphop.00017", "hiphop.00018",
"hiphop.00019", "hiphop.00020", "hiphop.00021", "hiphop.00022", "hiphop.00023", "hiphop.00024",
"hiphop.00025", "hiphop.00028", "hiphop.00029", "hiphop.00031", "hiphop.00032", "hiphop.00033",
"hiphop.00034", "hiphop.00035", "hiphop.00036", "hiphop.00037", "hiphop.00038", "hiphop.00041",
"hiphop.00042", "hiphop.00055", "hiphop.00056", "hiphop.00057", "hiphop.00058", "hiphop.00059",
"hiphop.00060", "hiphop.00061", "hiphop.00077", "hiphop.00078", "hiphop.00079", "hiphop.00080",
"jazz.00000", "jazz.00001", "jazz.00011", "jazz.00012", "jazz.00013", "jazz.00014",
"jazz.00015", "jazz.00016", "jazz.00017", "jazz.00018", "jazz.00019", "jazz.00020",
"jazz.00021", "jazz.00022", "jazz.00023", "jazz.00024", "jazz.00041", "jazz.00047",
"jazz.00048", "jazz.00049", "jazz.00050", "jazz.00051", "jazz.00052", "jazz.00053",
"jazz.00054", "jazz.00055", "jazz.00056", "jazz.00057", "jazz.00058", "jazz.00059",
"jazz.00060", "jazz.00061", "jazz.00062", "jazz.00063", "jazz.00064", "jazz.00065",
"jazz.00066", "jazz.00067", "jazz.00068", "jazz.00069", "jazz.00070", "jazz.00071",
"jazz.00072", "metal.00002", "metal.00003", "metal.00005", "metal.00021", "metal.00024",
"metal.00035", "metal.00046", "metal.00047", "metal.00048", "metal.00049", "metal.00050",
"metal.00051", "metal.00052", "metal.00053", "metal.00054", "metal.00055", "metal.00056",
"metal.00057", "metal.00059", "metal.00060", "metal.00061", "metal.00062", "metal.00063",
"metal.00064", "metal.00065", "metal.00066", "metal.00069", "metal.00071", "metal.00072",
"metal.00079", "metal.00080", "metal.00084", "metal.00086", "metal.00089", "metal.00090",
"metal.00091", "metal.00092", "metal.00093", "metal.00094", "metal.00095", "metal.00096",
"metal.00097", "metal.00098", "metal.00099", "pop.00002", "pop.00003", "pop.00004",
"pop.00005", "pop.00006", "pop.00007", "pop.00008", "pop.00009", "pop.00011",
"pop.00012", "pop.00016", "pop.00017", "pop.00018", "pop.00019", "pop.00020",
"pop.00023", "pop.00024", "pop.00025", "pop.00026", "pop.00027", "pop.00028",
"pop.00029", "pop.00031", "pop.00032", "pop.00033", "pop.00034", "pop.00035",
"pop.00036", "pop.00038", "pop.00039", "pop.00040", "pop.00041", "pop.00042",
"pop.00044", "pop.00046", "pop.00049", "pop.00050", "pop.00080", "pop.00097",
"pop.00098", "pop.00099", "reggae.00000", "reggae.00001", "reggae.00002", "reggae.00004",
"reggae.00006", "reggae.00009", "reggae.00011", "reggae.00012", "reggae.00014", "reggae.00015",
"reggae.00016", "reggae.00017", "reggae.00018", "reggae.00019", "reggae.00020", "reggae.00021",
"reggae.00022", "reggae.00023", "reggae.00024", "reggae.00025", "reggae.00026", "reggae.00027",
"reggae.00028", "reggae.00029", "reggae.00030", "reggae.00031", "reggae.00032", "reggae.00042",
"reggae.00043", "reggae.00044", "reggae.00045", "reggae.00049", "reggae.00050", "reggae.00051",
"reggae.00054", "reggae.00055", "reggae.00056", "reggae.00057", "reggae.00058", "reggae.00059",
"reggae.00060", "reggae.00063", "reggae.00069", "rock.00000", "rock.00001", "rock.00002",
"rock.00003", "rock.00004", "rock.00005", "rock.00006", "rock.00007", "rock.00008",
"rock.00009", "rock.00016", "rock.00017", "rock.00018", "rock.00019", "rock.00020",
"rock.00021", "rock.00022", "rock.00023", "rock.00024", "rock.00025", "rock.00026",
"rock.00057", "rock.00058", "rock.00059", "rock.00060", "rock.00061", "rock.00062",
"rock.00063", "rock.00064", "rock.00065", "rock.00066", "rock.00067", "rock.00068",
"rock.00069", "rock.00070", "rock.00091", "rock.00092", "rock.00093", "rock.00094",
"rock.00095", "rock.00096", "rock.00097", "rock.00098", "rock.00099",
};
// File stems (without the ".wav" extension) that make up the "valid" split of
// the filtered GTZAN partitioning; PrepareData falls back to this list for any
// usage_ that is not "test"/"train"/"all" (GTZANNode::ValidateParams restricts
// usage to "valid" in that case).
const std::vector<std::string> filtered_valid = {
"blues.00000", "blues.00001", "blues.00002", "blues.00003", "blues.00004", "blues.00005",
"blues.00006", "blues.00007", "blues.00008", "blues.00009", "blues.00010", "blues.00011",
"blues.00050", "blues.00051", "blues.00052", "blues.00053", "blues.00054", "blues.00055",
"blues.00056", "blues.00057", "blues.00058", "blues.00059", "blues.00060", "classical.00000",
"classical.00001", "classical.00002", "classical.00003", "classical.00004", "classical.00005", "classical.00006",
"classical.00007", "classical.00008", "classical.00009", "classical.00010", "classical.00068", "classical.00069",
"classical.00070", "classical.00071", "classical.00072", "classical.00073", "classical.00074", "classical.00075",
"classical.00076", "country.00000", "country.00001", "country.00002", "country.00003", "country.00004",
"country.00005", "country.00006", "country.00007", "country.00009", "country.00010", "country.00011",
"country.00012", "country.00013", "country.00014", "country.00015", "country.00016", "country.00017",
"country.00018", "country.00027", "country.00041", "country.00042", "country.00045", "country.00049",
"disco.00000", "disco.00002", "disco.00003", "disco.00004", "disco.00006", "disco.00007",
"disco.00008", "disco.00009", "disco.00010", "disco.00011", "disco.00012", "disco.00013",
"disco.00014", "disco.00046", "disco.00048", "disco.00052", "disco.00067", "disco.00068",
"disco.00072", "disco.00075", "disco.00090", "disco.00095", "hiphop.00081", "hiphop.00082",
"hiphop.00083", "hiphop.00084", "hiphop.00085", "hiphop.00086", "hiphop.00087", "hiphop.00088",
"hiphop.00089", "hiphop.00090", "hiphop.00091", "hiphop.00092", "hiphop.00093", "hiphop.00094",
"hiphop.00095", "hiphop.00096", "hiphop.00097", "hiphop.00098", "jazz.00002", "jazz.00003",
"jazz.00004", "jazz.00005", "jazz.00006", "jazz.00007", "jazz.00008", "jazz.00009",
"jazz.00010", "jazz.00025", "jazz.00026", "jazz.00027", "jazz.00028", "jazz.00029",
"jazz.00030", "jazz.00031", "jazz.00032", "metal.00000", "metal.00001", "metal.00006",
"metal.00007", "metal.00008", "metal.00009", "metal.00010", "metal.00011", "metal.00016",
"metal.00017", "metal.00018", "metal.00019", "metal.00020", "metal.00036", "metal.00037",
"metal.00068", "metal.00076", "metal.00077", "metal.00081", "metal.00082", "pop.00010",
"pop.00053", "pop.00055", "pop.00058", "pop.00059", "pop.00060", "pop.00061",
"pop.00062", "pop.00081", "pop.00083", "pop.00084", "pop.00085", "pop.00086",
"reggae.00061", "reggae.00062", "reggae.00070", "reggae.00072", "reggae.00074", "reggae.00076",
"reggae.00077", "reggae.00078", "reggae.00085", "reggae.00092", "reggae.00093", "reggae.00094",
"reggae.00095", "reggae.00096", "reggae.00097", "reggae.00098", "reggae.00099", "rock.00038",
"rock.00049", "rock.00050", "rock.00051", "rock.00052", "rock.00053", "rock.00054",
"rock.00055", "rock.00056", "rock.00071", "rock.00072", "rock.00073", "rock.00074",
"rock.00075", "rock.00076", "rock.00077", "rock.00078", "rock.00079", "rock.00080",
"rock.00081", "rock.00082", "rock.00083", "rock.00084", "rock.00085",
};
// Constructor: stores usage/path/schema and hands workers, queue size and the
// sampler to MappableLeafOp. No filesystem work happens here — the directory
// scan is deferred to PrepareData().
GTZANOp::GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
usage_(usage),
folder_path_(folder_path),
data_schema_(std::move(data_schema)) {}
// Builds one dataset row (waveform, sample_rate, label) for the given row id.
Status GTZANOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(trow);
  // Every GTZAN clip is 22050 Hz; ReadAudio rejects files with any other rate.
  constexpr uint32_t kSampleRate = 22050;
  const auto &file_path = audio_names_[row_id].first;
  const auto &genre = audio_names_[row_id].second;
  std::shared_ptr<Tensor> waveform;
  std::shared_ptr<Tensor> rate;
  std::shared_ptr<Tensor> label;
  RETURN_IF_NOT_OK(ReadAudio(file_path, &waveform));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(kSampleRate, &rate));
  RETURN_IF_NOT_OK(Tensor::CreateScalar(genre, &label));
  *trow = TensorRow(row_id, {std::move(waveform), std::move(rate), std::move(label)});
  // All three columns originate from the same source file.
  trow->setPath({file_path, file_path, file_path});
  return Status::OK();
}
// Debug printer: the base-class summary, then either a bare newline or the
// GTZAN-specific details depending on show_all.
void GTZANOp::Print(std::ostream &out, bool show_all) const {
  ParallelOp::Print(out, show_all);
  if (show_all) {
    out << "\nNumber of rows: " << num_rows_ << "\nGTZAN directory: " << folder_path_ << "\n\n";
  } else {
    out << "\n";
  }
}
// Counts the samples of the requested GTZAN subset by building a throwaway op
// and running its directory scan; *count is left at 0 on failure.
Status GTZANOp::CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  *count = 0;
  // Sequential sampler over everything (num_samples = 0 means "no limit").
  constexpr int64_t kStartIndex = 0;
  constexpr int64_t kNumSamples = 0;
  auto sampler = std::make_shared<SequentialSamplerRT>(kStartIndex, kNumSamples);
  // Same three-column schema the real pipeline uses.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT32), TensorImpl::kCv, 1)));
  TensorShape scalar_rate = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(schema->AddColumn(
    ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate)));
  TensorShape scalar_label = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label)));
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  auto op = std::make_shared<GTZANOp>(usage, cfg->num_parallel_workers(), dir, cfg->op_connector_size(),
                                      std::move(schema), std::move(sampler));
  RETURN_IF_NOT_OK(op->PrepareData());
  *count = static_cast<int64_t>(op->audio_names_.size());
  return Status::OK();
}
// Maps each schema column name to its index; a no-op (with a warning) if the
// map was already populated.
Status GTZANOp::ComputeColMap() {
  if (!column_name_id_map_.empty()) {
    MS_LOG(WARNING) << "Column name map is already set!";
    return Status::OK();
  }
  for (int32_t col = 0; col < data_schema_->NumColumns(); ++col) {
    column_name_id_map_[data_schema_->Column(col).Name()] = col;
  }
  return Status::OK();
}
// Reads one wav file into a (1, num_samples) float tensor, rejecting any file
// whose sampling rate is not the GTZAN-standard 22050 Hz.
Status GTZANOp::ReadAudio(const std::string &audio_dir, std::shared_ptr<Tensor> *waveform) {
  RETURN_UNEXPECTED_IF_NULL(waveform);
  constexpr int32_t kWavFileSampleRate = 22050;
  std::vector<float> samples;
  int32_t sample_rate = 0;
  RETURN_IF_NOT_OK(ReadWaveFile(audio_dir, &samples, &sample_rate));
  CHECK_FAIL_RETURN_UNEXPECTED(sample_rate == kWavFileSampleRate,
                               "Invalid file, sampling rate of GTZAN wav file must be 22050, file path: " + audio_dir);
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(samples, waveform));
  // Prepend a channel dimension: (n,) -> (1, n).
  RETURN_IF_NOT_OK((*waveform)->ExpandDim(0));
  return Status::OK();
}
// Walks the dataset directory and fills audio_names_ with (file path, genre)
// pairs for the requested subset.
//
// Fixes over the original:
//  - a scanned file whose full path contains no '.' made find_last_of return
//    npos, and file_name.substr(npos) throws std::out_of_range; such files are
//    now skipped with a warning instead of crashing the worker;
//  - the range-for loops copied every std::string element; they now iterate by
//    const reference and use emplace_back.
//
// \return Status The status code returned (error if the dir is missing or no
//         valid audio file was found).
Status GTZANOp::PrepareData() {
  auto realpath = FileUtils::GetRealPath(folder_path_.data());
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, GTZAN Dataset dir: " << folder_path_ << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, GTZAN Dataset dir: " + folder_path_ + " does not exist.");
  }
  Path dir(folder_path_);
  if (usage_ == "all") {
    // Scan every genre sub-directory for "<genre>.<id>.wav" files.
    for (const std::string &sub_directory : genres) {
      Path full_dir = dir / sub_directory;
      if (!full_dir.Exists() || !full_dir.IsDirectory()) {
        continue;
      }
      auto dir_it = Path::DirIterator::OpenDirectory(&full_dir);
      if (dir_it == nullptr) {
        MS_LOG(WARNING) << "Invalid file path, unable to open directory: " << full_dir.ToString() << ".";
        continue;
      }
      while (dir_it->HasNext()) {
        Path file = dir_it->Next();
        std::string file_name = file.ToString();
        auto pos = file_name.find_last_of('.');
        if (pos == std::string::npos) {
          // No extension at all (e.g. a stray "README"); substr(pos) would throw.
          MS_LOG(WARNING) << "Invalid file, invalid file name or file type: " << file.ToString() << ".";
          continue;
        }
        std::string name = file_name.substr(0, pos);
        std::string temp_ext = file_name.substr(pos);
        // Genuine GTZAN files look like "blues.00000.wav", so the stem must
        // still contain a dot once ".wav" is stripped.
        if (temp_ext == ".wav" && name.find('.') != std::string::npos) {
          audio_names_.emplace_back(file_name, sub_directory);
        } else {
          MS_LOG(WARNING) << "Invalid file, invalid file name or file type: " << file.ToString() << ".";
        }
      }
    }
  } else {
    // For the filtered splits, look up each expected stem under its genre dir;
    // missing files are tolerated with a warning. Any usage other than
    // "test"/"train" reaches the valid list (ValidateParams restricts usage).
    const std::vector<std::string> *files_point = nullptr;
    if (usage_ == "test") {
      files_point = &filtered_test;
    } else if (usage_ == "train") {
      files_point = &filtered_train;
    } else {
      files_point = &filtered_valid;
    }
    const std::string ext = ".wav";
    for (const auto &sub_file_name : *files_point) {
      // The stem's leading "<genre>." prefix names the sub-directory.
      auto pos = sub_file_name.find_first_of('.');
      std::string cls = sub_file_name.substr(0, pos);
      Path full_dir = dir / cls / (sub_file_name + ext);
      if (full_dir.Exists()) {
        audio_names_.emplace_back(full_dir.ToString(), cls);
      } else {
        MS_LOG(WARNING) << "The audio file is lost, file name= " << (sub_file_name + ext);
      }
    }
  }
  num_rows_ = audio_names_.size();
  CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Invalid data, no valid data found in path:" + folder_path_);
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 97
- 0
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/gtzan_op.h View File

@@ -0,0 +1,97 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
// Runtime leaf operator that reads the GTZAN music-genre dataset
// (wav files grouped by genre sub-directory) and emits rows of
// (waveform, sample_rate, label).
class GTZANOp : public MappableLeafOp {
public:
/// \brief Constructor.
/// \param[in] usage Usage of this dataset, can be 'train', 'valid', 'test', or 'all'.
/// \param[in] num_workers Number of workers reading audios in parallel.
/// \param[in] folder_path Root directory of the GTZAN dataset.
/// \param[in] queue_size Connector queue size.
/// \param[in] data_schema The schema of the GTZAN dataset.
/// \param[in] sampler Sampler tells GTZANOp what to read.
GTZANOp(const std::string &usage, int32_t num_workers, const std::string &folder_path, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
/// \brief Destructor.
~GTZANOp() = default;
/// \brief A print method typically used for debugging.
/// \param[out] out Output stream.
/// \param[in] show_all Whether to show all information.
void Print(std::ostream &out, bool show_all) const override;
/// \brief Function to count the number of samples in the GTZAN dataset.
/// \param[in] dir Path to the GTZAN directory.
/// \param[in] usage Choose the subset of the GTZAN dataset.
/// \param[out] count Output arg that will hold the actual dataset size.
/// \return Status The status code returned.
static Status CountTotalRows(const std::string &dir, const std::string &usage, int64_t *count);
/// \brief Op name getter.
/// \return Name of the current Op.
std::string Name() const override { return "GTZANOp"; }
private:
/// \brief Load a tensor row according to a (file path, genre label) pair.
/// \param[in] row_id Id for this tensor row.
/// \param[out] row Audio & label read into this tensor row.
/// \return Status The status code returned.
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
/// \brief Parse an audio file into a waveform tensor.
/// \param[in] audio_dir Audio file path.
/// \param[out] waveform The output waveform tensor.
/// \return Status The status code returned.
Status ReadAudio(const std::string &audio_dir, std::shared_ptr<Tensor> *waveform);
/// \brief Prepare data: scan the dataset directory and collect (file, label) pairs.
/// \return Status The status code returned.
Status PrepareData();
/// \brief Private function for computing the assignment of the column name map.
/// \return Status The status code returned.
Status ComputeColMap() override;
// Requested subset: 'train', 'valid', 'test' or 'all'.
const std::string usage_;
// Root directory of the GTZAN dataset.
std::string folder_path_;
// Schema of the three output columns.
std::unique_ptr<DataSchema> data_schema_;
// (absolute wav file path, genre label) pairs collected by PrepareData.
std::vector<std::pair<std::string, std::string>> audio_names_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_GTZAN_OP_H_

+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/dataset_node.h View File

@@ -96,6 +96,7 @@ constexpr char kFakeImageNode[] = "FakeImageDataset";
constexpr char kFashionMnistNode[] = "FashionMnistDataset";
constexpr char kFlickrNode[] = "FlickrDataset";
constexpr char kGeneratorNode[] = "GeneratorDataset";
constexpr char kGTZANNode[] = "GTZANDataset";
constexpr char kImageFolderNode[] = "ImageFolderDataset";
constexpr char kIMDBNode[] = "IMDBDataset";
constexpr char kIWSLT2016Node[] = "IWSLT2016Dataset";


+ 1
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/CMakeLists.txt View File

@@ -22,6 +22,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
fake_image_node.cc
fashion_mnist_node.cc
flickr_node.cc
gtzan_node.cc
image_folder_node.cc
imdb_node.cc
iwslt2016_node.cc


+ 110
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.cc View File

@@ -0,0 +1,110 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/gtzan_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/gtzan_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor: record the dataset directory, usage split and sampler; the optional
// cache is forwarded to the MappableSourceNode base class.
GTZANNode::GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
                     std::shared_ptr<DatasetCache> cache)
    : MappableSourceNode(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}
void GTZANNode::Print(std::ostream &out) const { out << Name(); }
std::shared_ptr<DatasetNode> GTZANNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<GTZANNode>(dataset_dir_, usage_, sampler, cache_);
return node;
}
// Validate construction arguments before the pipeline is built: base-class checks,
// then dataset directory existence, then sampler, then usage string.
// \return Status The error status of the first failing check, or OK.
Status GTZANNode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("GTZANDataset", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateDatasetSampler("GTZANDataset", sampler_));
  // Only these four splits are accepted; any other usage string is rejected here.
  RETURN_IF_NOT_OK(ValidateStringValue("GTZANDataset", usage_, {"train", "valid", "test", "all"}));
  return Status::OK();
}
// Build the runtime GTZANOp and its output schema.
// Columns produced: "waveform" (rank-1, float64), "sample_rate" (scalar uint32),
// "label" (scalar string).
// NOTE(review): the Python API docstring describes the waveform column as float32,
// while DE_FLOAT64 is declared here — confirm which type is intended.
Status GTZANNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  // Do internal Schema generation.
  auto schema = std::make_unique<DataSchema>();
  RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("waveform", DataType(DataType::DE_FLOAT64), TensorImpl::kCv, 1)));
  TensorShape scalar_rate = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(schema->AddColumn(
    ColDescriptor("sample_rate", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar_rate)));
  TensorShape scalar_label = TensorShape::CreateScalar();
  RETURN_IF_NOT_OK(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_STRING), TensorImpl::kFlexible, 0, &scalar_label)));
  // Convert the IR sampler into its runtime counterpart before handing it to the op.
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  auto op = std::make_shared<GTZANOp>(usage_, num_workers_, dataset_dir_, connector_que_size_, std::move(schema),
                                      std::move(sampler_rt));
  // Propagate repeat bookkeeping so the op knows its epoch/repeat schedule.
  op->SetTotalRepeats(GetTotalRepeats());
  op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
  return Status::OK();
}
// Get the shard id of node.
// NOTE(review): sampler_ is dereferenced without a null check; presumably
// ValidateDatasetSampler() in ValidateParams() rejects null samplers first — confirm.
Status GTZANNode::GetShardId(int32_t *shard_id) {
  *shard_id = sampler_->ShardId();
  return Status::OK();
}
// Get Dataset size.
// Returns the cached value when one is available; otherwise counts the rows on disk
// via GTZANOp::CountTotalRows and lets the sampler scale that count. When the sampler
// cannot compute its sample count analytically (CalculateNumSamples returns -1), a
// dry run of the pipeline is used to measure the size.
Status GTZANNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                 int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    // Size already computed on a previous call; reuse it.
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows, sample_size;
  RETURN_IF_NOT_OK(GTZANOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
  std::shared_ptr<SamplerRT> sampler_rt = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
  sample_size = sampler_rt->CalculateNumSamples(num_rows);
  if (sample_size == -1) {
    // Sampler cannot determine the count up front; execute a dry run to measure it.
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
  }
  *dataset_size = sample_size;
  dataset_size_ = *dataset_size;  // cache for subsequent calls
  return Status::OK();
}
// Serialize this node's construction arguments (sampler, worker count, dataset
// directory, usage and the optional cache) into JSON for pipeline save/load.
Status GTZANNode::to_json(nlohmann::json *out_json) {
  nlohmann::json args, sampler_args;
  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
  args["sampler"] = sampler_args;
  args["num_parallel_workers"] = num_workers_;
  args["dataset_dir"] = dataset_dir_;
  args["usage"] = usage_;
  if (cache_ != nullptr) {
    // Cache is optional; only serialized when present.
    nlohmann::json cache_args;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
    args["cache"] = cache_args;
  }
  *out_json = args;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 95
- 0
mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/gtzan_node.h View File

@@ -0,0 +1,95 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class GTZANNode : public MappableSourceNode {
public:
/// \brief Constructor
GTZANNode(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~GTZANNode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return "kGTZANNode"; }
/// \brief Print the description.
/// \param out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class.
/// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[in] shard_id The shard ID within num_shards.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
/// \param[in] sampler Tells GTZANOp what to read.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
std::string dataset_dir_;
std::string usage_;
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_GTZAN_NODE_H_

+ 87
- 0
mindspore/ccsrc/minddata/dataset/include/dataset/datasets.h View File

@@ -2647,6 +2647,93 @@ inline std::shared_ptr<FlickrDataset> MS_API Flickr(const std::string &dataset_d
cache);
}

/// \class GTZANDataset
/// \brief A source dataset for reading and parsing GTZAN dataset.
/// \note Prefer constructing through the inline GTZAN() factory functions below,
///     which convert std::string arguments to std::vector<char> via StringToChar.
class MS_API GTZANDataset : public Dataset {
 public:
  /// \brief Constructor of GTZANDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all".
  /// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of GTZANDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all".
  /// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
  /// \param[in] cache Tensor cache to use.
  GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, const Sampler *sampler,
               const std::shared_ptr<DatasetCache> &cache);

  /// \brief Constructor of GTZANDataset.
  /// \param[in] dataset_dir Path to the root directory that contains the dataset.
  /// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all".
  /// \param[in] sampler Sampler object used to choose samples from the dataset, passed by reference wrapper.
  /// \param[in] cache Tensor cache to use.
  GTZANDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
               const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor of GTZANDataset.
  ~GTZANDataset() = default;
};

/// \brief Function to create a GTZANDataset.
/// \note The generated dataset has three columns ["waveform", "sample_rate", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all" (default = "all").
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
///     given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the GTZANDataset.
/// \code
///      /* Define dataset path and MindData object */
///      std::string folder_path = "/path/to/gtzan_dataset_directory";
///      std::shared_ptr<Dataset> ds = GTZAN(folder_path, "all", std::make_shared<RandomSampler>(false, 10));
///
///      /* Create iterator to read dataset */
///      std::shared_ptr<Iterator> iter = ds->CreateIterator();
///      std::unordered_map<std::string, mindspore::MSTensor> row;
///      iter->GetNextRow(&row);
///
///      /* Note: In GTZAN dataset, each data dictionary has keys "waveform", "sample_rate" and "label" */
///      auto waveform = row["waveform"];
/// \endcode
inline std::shared_ptr<GTZANDataset> MS_API
GTZAN(const std::string &dataset_dir, const std::string &usage = "all",
      const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
      const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<GTZANDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}

/// \brief Create a GTZANDataset using a raw sampler pointer.
/// \note The resulting dataset exposes the columns ["waveform", "sample_rate", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all".
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the GTZANDataset.
inline std::shared_ptr<GTZANDataset> MS_API GTZAN(const std::string &dataset_dir, const std::string &usage,
                                                  const Sampler *sampler,
                                                  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  auto dir_chars = StringToChar(dataset_dir);
  auto usage_chars = StringToChar(usage);
  return std::make_shared<GTZANDataset>(dir_chars, usage_chars, sampler, cache);
}

/// \brief Function to create a GTZANDataset.
/// \note The generated dataset has three columns ["waveform", "sample_rate", "label"].
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage Part of dataset of GTZAN, can be "train", "valid", "test", or "all".
/// \param[in] sampler Sampler object used to choose samples from the dataset, passed by std::reference_wrapper.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the GTZANDataset.
inline std::shared_ptr<GTZANDataset> MS_API GTZAN(const std::string &dataset_dir, const std::string &usage,
                                                  const std::reference_wrapper<Sampler> sampler,
                                                  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  return std::make_shared<GTZANDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}

/// \class ImageFolderDataset
/// \brief A source dataset that reads images from a tree of directories.
class MS_API ImageFolderDataset : public Dataset {


+ 1
- 0
mindspore/ccsrc/minddata/dataset/include/dataset/samplers.h View File

@@ -45,6 +45,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
friend class FakeImageDataset;
friend class FashionMnistDataset;
friend class FlickrDataset;
friend class GTZANDataset;
friend class ImageFolderDataset;
friend class IMDBDataset;
friend class KMnistDataset;


+ 1
- 0
mindspore/python/mindspore/dataset/engine/__init__.py View File

@@ -81,6 +81,7 @@ __all__ = ["Caltech101Dataset", # Vision
"WikiTextDataset", # Text
"YahooAnswersDataset", # Text
"YelpReviewDataset", # Text
"GTZANDataset", # Audio
"LJSpeechDataset", # Audio
"SpeechCommandsDataset", # Audio
"TedliumDataset", # Audio


+ 132
- 2
mindspore/python/mindspore/dataset/engine/datasets_audio.py View File

@@ -26,12 +26,142 @@ After declaring the dataset object, you can further apply dataset operations
import mindspore._c_dataengine as cde

from .datasets import AudioBaseDataset, MappableDataset
from .validators import check_lj_speech_dataset, check_yes_no_dataset, check_speech_commands_dataset, \
check_tedlium_dataset
from .validators import check_gtzan_dataset, check_lj_speech_dataset, check_speech_commands_dataset, check_tedlium_dataset, \
check_yes_no_dataset

from ..core.validator_helpers import replace_none


class GTZANDataset(MappableDataset, AudioBaseDataset):
    """
    A source dataset that reads and parses GTZAN dataset.

    The generated dataset has three columns: :py:obj:`["waveform", "sample_rate", "label"]`.
    The tensor of column :py:obj:`waveform` is of the float32 type.
    The tensor of column :py:obj:`sample_rate` is of a scalar of uint32 type.
    The tensor of column :py:obj:`label` is of a scalar of string type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be "train", "valid", "test" or "all"
            (default=None, all samples).
        num_samples (int, optional): The number of audio to be included in the dataset
            (default=None, will read all audio).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, will use value set in the config).
        shuffle (bool, optional): Whether or not to perform shuffle on the dataset
            (default=None, expected order behavior shown in the table).
        sampler (Sampler, optional): Object used to choose samples from the
            dataset (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
            When this argument is specified, `num_samples` reflects the max sample number of per shard.
        shard_id (int, optional): The shard ID within `num_shards` (default=None). This
            argument can only be specified when `num_shards` is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If source raises an exception during execution.
        RuntimeError: If dataset_dir does not contain data files.
        RuntimeError: If num_parallel_workers exceeds the max thread numbers.
        RuntimeError: If sampler and shuffle are specified at the same time.
        RuntimeError: If sampler and sharding are specified at the same time.
        RuntimeError: If num_shards is specified but shard_id is None.
        RuntimeError: If shard_id is specified but num_shards is None.
        ValueError: If shard_id is invalid (< 0 or >= num_shards).

    Note:
        - GTZAN doesn't support PKSampler.
        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter `sampler`
         - Parameter `shuffle`
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> gtzan_dataset_directory = "/path/to/gtzan_dataset_directory"
        >>>
        >>> # 1) Read 500 samples (audio files) in gtzan_dataset_directory
        >>> dataset = ds.GTZANDataset(gtzan_dataset_directory, usage="all", num_samples=500)
        >>>
        >>> # 2) Read all samples (audio files) in gtzan_dataset_directory
        >>> dataset = ds.GTZANDataset(gtzan_dataset_directory)

    About GTZAN dataset:

    The GTZAN dataset appears in at least 100 published works and is the most commonly used
    public dataset for evaluation in machine listening research for music genre recognition.
    It consists of 1000 audio tracks, each of which is 30 seconds long. It contains 10 genres (blues,
    classical, country, disco, hiphop, jazz, metal, pop, reggae and rock), each of which is
    represented by 100 tracks. The tracks are all 22050Hz Mono 16-bit audio files in .wav format.

    You can construct the following directory structure from GTZAN dataset and read by MindSpore's API.

    .. code-block::

        .
        └── gtzan_dataset_directory
            ├── blues
            │    ├──blues.00000.wav
            │    ├──blues.00001.wav
            │    ├──blues.00002.wav
            │    ├──...
            ├── disco
            │    ├──disco.00000.wav
            │    ├──disco.00001.wav
            │    ├──disco.00002.wav
            │    └──...
            └──...

    Citation:

    .. code-block::

        @misc{tzanetakis_essl_cook_2001,
        author    = "Tzanetakis, George and Essl, Georg and Cook, Perry",
        title     = "Automatic Musical Genre Classification Of Audio Signals",
        url       = "http://ismir2001.ismir.net/pdf/tzanetakis.pdf",
        publisher = "The International Society for Music Information Retrieval",
        year      = "2001"
        }
    """

    @check_gtzan_dataset
    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
                 sampler=None, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)

        self.dataset_dir = dataset_dir
        # usage=None means "read everything"; normalize to the explicit "all" split.
        self.usage = replace_none(usage, "all")

    # NOTE(review): this docstring documents the waveform column as float32, while the
    # C++ GTZANNode schema declares DE_FLOAT64 — confirm which type is intended.
    def parse(self, children=None):
        # Build the C++ IR node for this dataset; the sampler is resolved by the base class.
        return cde.GTZANNode(self.dataset_dir, self.usage, self.sampler)


class LJSpeechDataset(MappableDataset, AudioBaseDataset):
"""
A source dataset that reads and parses LJSpeech dataset.


+ 30
- 0
mindspore/python/mindspore/dataset/engine/validators.py View File

@@ -34,6 +34,36 @@ from . import samplers
from . import cache_client


def check_gtzan_dataset(method):
    """A wrapper that wraps a parameter checker around the original GTZANDataset."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # The dataset directory must exist and be readable.
        check_dir(param_dict.get('dataset_dir'))

        # usage is optional; when given it must name a supported split.
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ['train', 'valid', 'test', 'all'], "usage")

        # Type-check the optional numeric and boolean arguments.
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        validate_dataset_param_value(['shuffle'], param_dict, bool)

        # sampler/shuffle/sharding are mutually constrained; enforce the combinations.
        check_sampler_shuffle_shard_options(param_dict)

        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method


def check_imagefolderdataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(ImageFolderDataset)."""



+ 1
- 0
tests/ut/cpp/dataset/CMakeLists.txt View File

@@ -32,6 +32,7 @@ SET(DE_UT_SRCS
c_api_dataset_fake_image_test.cc
c_api_dataset_fashion_mnist_test.cc
c_api_dataset_flickr_test.cc
c_api_dataset_gtzan_test.cc
c_api_dataset_imdb_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_iwslt_test.cc


+ 271
- 0
tests/ut/cpp/dataset/c_api_dataset_gtzan_test.cc View File

@@ -0,0 +1,271 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "include/dataset/datasets.h"
#include "include/dataset/transforms.h"
using namespace mindspore::dataset;
using mindspore::dataset::Tensor;
// Test fixture shared by the GTZAN pipeline tests below; the base class supplies
// datasets_root_path_ pointing at the unit-test data directory.
class MindDataTestPipeline : public UT::DatasetOpTesting {
 protected:
};
/// Feature: GTZANDataset
/// Description: test GTZAN
/// Expectation: get correct GTZAN dataset
TEST_F(MindDataTestPipeline, TestGTZANBasic) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANBasic.";
  std::string file_path = datasets_root_path_ + "/testGTZANData";
  // Create a GTZAN Dataset
  std::shared_ptr<Dataset> ds = GTZAN(file_path);
  EXPECT_NE(ds, nullptr);
  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);
  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  std::string_view label_idx;
  uint32_t rate = 0;
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto waveform = row["waveform"];
    auto label = row["label"];
    auto sample_rate = row["sample_rate"];
    MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape();
    // All test fixtures are 22050 Hz mono wav files.
    std::shared_ptr<Tensor> trate;
    ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &trate));
    ASSERT_OK(trate->GetItemAt<uint32_t>(&rate, {}));
    EXPECT_EQ(rate, 22050);
    MS_LOG(INFO) << "Tensor label rate: " << rate;
    // The test data only contains the "blues" genre directory.
    std::shared_ptr<Tensor> de_label;
    ASSERT_OK(Tensor::CreateFromMSTensor(label, &de_label));
    ASSERT_OK(de_label->GetItemAt(&label_idx, {}));
    std::string s_label(label_idx);
    std::string expected("blues");
    EXPECT_STREQ(s_label.c_str(), expected.c_str());
    MS_LOG(INFO) << "Tensor label value: " << label_idx;
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // testGTZANData contains exactly 3 wav files.
  EXPECT_EQ(i, 3);
  // Manually terminate the pipeline
  iter->Stop();
}
/// Feature: GTZANDataset
/// Description: test GTZAN with Pipeline
/// Expectation: get correct GTZAN dataset
TEST_F(MindDataTestPipeline, TestGTZANBasicWithPipeline) {
  MS_LOG(INFO) << "Doing DataSetOpBatchTest-TestGTZANBasicWithPipeline.";
  // Create a GTZANDataset Dataset.
  std::string folder_path = datasets_root_path_ + "/testGTZANData";
  std::shared_ptr<Dataset> ds = GTZAN(folder_path, "all", std::make_shared<RandomSampler>(false, 2));
  EXPECT_NE(ds, nullptr);
  // Pad every waveform to a fixed length so rows can be batched together.
  auto op = transforms::PadEnd({1, 50000});
  std::vector<std::string> input_columns = {"waveform"};
  std::vector<std::string> output_columns = {"waveform"};
  std::vector<std::string> project_columns = {"label", "waveform", "sample_rate"};
  ds = ds->Map({op}, input_columns, output_columns, project_columns);
  EXPECT_NE(ds, nullptr);
  // 2 samples x 10 repeats = 20 rows, batched by 5 -> 4 batches.
  ds = ds->Repeat(10);
  EXPECT_NE(ds, nullptr);
  ds = ds->Batch(5);
  EXPECT_NE(ds, nullptr);
  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);
  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  // Fixed: check the status of the first GetNextRow like every other call in this
  // file, instead of silently discarding it.
  ASSERT_OK(iter->GetNextRow(&row));
  std::vector<uint32_t> expected_rate = {22050, 22050, 22050, 22050, 22050};
  std::vector<std::string> expected_label = {"blues", "blues", "blues", "blues", "blues"};
  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto waveform = row["waveform"];
    auto label = row["label"];
    auto sample_rate = row["sample_rate"];
    // Each batch should hold five identical sample rates and "blues" labels.
    std::shared_ptr<Tensor> de_expected_rate;
    ASSERT_OK(Tensor::CreateFromVector(expected_rate, &de_expected_rate));
    mindspore::MSTensor fix_expected_rate =
      mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_rate));
    EXPECT_MSTENSOR_EQ(sample_rate, fix_expected_rate);
    std::shared_ptr<Tensor> de_expected_label;
    ASSERT_OK(Tensor::CreateFromVector(expected_label, &de_expected_label));
    mindspore::MSTensor fix_expected_label =
      mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_expected_label));
    EXPECT_MSTENSOR_EQ(label, fix_expected_label);
    ASSERT_OK(iter->GetNextRow(&row));
  }
  EXPECT_EQ(i, 4);
  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: GTZANDataset
/// Description: test GTZAN with invalid directory
/// Expectation: iterator creation fails for both invalid inputs
TEST_F(MindDataTestPipeline, TestGTZANError) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANError.";
  // Create a GTZAN Dataset with non-existing dataset dir.
  std::shared_ptr<Dataset> ds0 = GTZAN("NotExistFile");
  EXPECT_NE(ds0, nullptr);
  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter0 = ds0->CreateIterator();
  // Expect failure: invalid GTZAN input (non-existing directory).
  EXPECT_EQ(iter0, nullptr);
  // Create a GTZAN Dataset with invalid string of dataset dir.
  std::shared_ptr<Dataset> ds1 = GTZAN(":*?\"<>|`&;'");
  EXPECT_NE(ds1, nullptr);
  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
  // Expect failure: invalid GTZAN input.
  EXPECT_EQ(iter1, nullptr);
}
/// Feature: GTZANDataset
/// Description: test GTZAN with Getters
/// Expectation: get correct dataset size and column names for each usage
TEST_F(MindDataTestPipeline, TestGTZANGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANGetters.";
  std::string folder_path = datasets_root_path_ + "/testGTZANData";
  // Create a GTZAN Dataset with default usage, explicit "all" and "valid".
  std::shared_ptr<Dataset> ds1 = GTZAN(folder_path);
  std::shared_ptr<Dataset> ds2 = GTZAN(folder_path, "all");
  std::shared_ptr<Dataset> ds3 = GTZAN(folder_path, "valid");
  std::vector<std::string> column_names = {"waveform", "sample_rate", "label"};
  // All three usages see the same 3 test files and the same schema.
  EXPECT_NE(ds1, nullptr);
  EXPECT_EQ(ds1->GetDatasetSize(), 3);
  EXPECT_EQ(ds1->GetColumnNames(), column_names);
  EXPECT_NE(ds2, nullptr);
  EXPECT_EQ(ds2->GetDatasetSize(), 3);
  EXPECT_EQ(ds2->GetColumnNames(), column_names);
  EXPECT_NE(ds3, nullptr);
  EXPECT_EQ(ds3->GetDatasetSize(), 3);
  EXPECT_EQ(ds3->GetColumnNames(), column_names);
}
/// Feature: GTZANDataset
/// Description: test GTZAN dataset with invalid usage
/// Expectation: iterator creation fails
TEST_F(MindDataTestPipeline, TestGTZANWithInvalidUsageError) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithInvalidUsageError.";
  std::string folder_path = datasets_root_path_ + "/testGTZANData";
  // Create a GTZAN Dataset with a usage string outside {"train", "valid", "test", "all"}.
  std::shared_ptr<Dataset> ds1 = GTZAN(folder_path, "----");
  EXPECT_NE(ds1, nullptr);
  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
  EXPECT_EQ(iter1, nullptr);
  // A second arbitrary invalid usage string must be rejected the same way.
  std::shared_ptr<Dataset> ds2 = GTZAN(folder_path, "csacs");
  EXPECT_NE(ds2, nullptr);
  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
  EXPECT_EQ(iter2, nullptr);
}
/// Feature: GTZANDataset
/// Description: test GTZAN dataset with null sampler
/// Expectation: iterator creation fails
TEST_F(MindDataTestPipeline, TestGTZANWithNullSamplerError) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANWithNullSamplerError.";
  std::string folder_path = datasets_root_path_ + "/testGTZANData";
  // Create a GTZAN Dataset with a valid usage but a null sampler.
  // Fixed: usage was "all " (trailing space), which is itself an invalid usage and
  // would make the pipeline fail for the wrong reason; "all" ensures this test
  // exercises only the null-sampler error path.
  std::shared_ptr<Dataset> ds = GTZAN(folder_path, "all", nullptr);
  EXPECT_NE(ds, nullptr);
  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid GTZAN input, sampler cannot be nullptr.
  EXPECT_EQ(iter, nullptr);
}
/// Feature: GTZANDataset
/// Description: test GTZAN with sequential sampler
/// Expectation: get correct GTZAN dataset
TEST_F(MindDataTestPipeline, TestGTZANNumSamplers) {
  // Fixed: the log message named a different test ("TestGTZANWithSequentialSampler");
  // it now matches the actual test name so log output can be correlated with gtest.
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGTZANNumSamplers.";
  std::string folder_path = datasets_root_path_ + "/testGTZANData";
  // Create a GTZAN Dataset reading the first 2 samples in order.
  std::shared_ptr<Dataset> ds = GTZAN(folder_path, "all", std::make_shared<SequentialSampler>(0, 2));
  EXPECT_NE(ds, nullptr);
  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);
  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  uint32_t rate = 0;
  uint64_t i = 0;
  while (row.size() != 0) {
    auto waveform = row["waveform"];
    auto sample_rate = row["sample_rate"];
    MS_LOG(INFO) << "Tensor waveform shape: " << waveform.Shape();
    // All fixtures are 22050 Hz wav files.
    std::shared_ptr<Tensor> t_rate;
    ASSERT_OK(Tensor::CreateFromMSTensor(sample_rate, &t_rate));
    ASSERT_OK(t_rate->GetItemAt<uint32_t>(&rate, {}));
    EXPECT_EQ(rate, 22050);
    MS_LOG(INFO) << "Tensor sample rate: " << rate;
    i++;
    ASSERT_OK(iter->GetNextRow(&row));
  }
  // The sequential sampler limits the run to exactly 2 rows.
  EXPECT_EQ(i, 2);
  iter->Stop();
}

BIN
tests/ut/data/dataset/testGTZANData/blues/blues.00000.wav View File


BIN
tests/ut/data/dataset/testGTZANData/blues/blues.00001.wav View File


BIN
tests/ut/data/dataset/testGTZANData/blues/blues.00002.wav View File


+ 223
- 0
tests/ut/python/dataset/test_datasets_gtzan.py View File

@@ -0,0 +1,223 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Test Gtzan dataset operators.
"""
import numpy as np
import pytest
import mindspore.dataset as ds
from mindspore import log as logger
DATA_DIR = "../data/dataset/testGTZANData"
def test_gtzan_basic():
    """
    Feature: GTZANDataset
    Description: test basic usage of GTZAN
    Expectation: the dataset is as expected
    """
    logger.info("Test GTZANDataset Op")

    # case 1: test loading whole dataset.
    whole = ds.GTZANDataset(DATA_DIR)
    assert sum(1 for _ in whole.create_dict_iterator(output_numpy=True, num_epochs=1)) == 3

    # case 2: test num_samples.
    limited = ds.GTZANDataset(DATA_DIR, num_samples=2)
    assert sum(1 for _ in limited.create_dict_iterator(output_numpy=True, num_epochs=1)) == 2

    # case 3: test repeat.
    repeated = ds.GTZANDataset(DATA_DIR, num_samples=2).repeat(5)
    assert sum(1 for _ in repeated.create_dict_iterator(output_numpy=True, num_epochs=1)) == 10

    # case 4: test batch with drop_remainder=False (the default).
    batched = ds.GTZANDataset(DATA_DIR, num_samples=3)
    assert batched.get_dataset_size() == 3
    assert batched.get_batch_size() == 1
    batched = batched.batch(batch_size=2)
    assert batched.get_dataset_size() == 2
    assert batched.get_batch_size() == 2

    # case 5: test batch with drop_remainder=True; the incomplete batch is dropped.
    dropped = ds.GTZANDataset(DATA_DIR, num_samples=3)
    assert dropped.get_dataset_size() == 3
    assert dropped.get_batch_size() == 1
    dropped = dropped.batch(batch_size=2, drop_remainder=True)
    assert dropped.get_dataset_size() == 1
    assert dropped.get_batch_size() == 2
def test_gtzan_distribute_sampler():
    """
    Feature: GTZANDataset
    Description: test GTZAN dataset with DistributedSampler
    Expectation: the results are as expected
    """
    logger.info("Test GTZAN with DistributedSampler")
    num_shards, shard_id = 3, 0

    def collect_labels(dataset):
        # Gather every label produced by one pass over the dataset.
        labels = []
        for row in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
            labels.append(row["label"])
        return labels

    # Shard via the num_shards/shard_id keyword arguments.
    data1 = ds.GTZANDataset(DATA_DIR, usage="all", num_shards=num_shards, shard_id=shard_id)
    label_list1 = collect_labels(data1)
    assert len(label_list1) == 1

    # Shard via an explicit DistributedSampler; results must match.
    sampler = ds.DistributedSampler(num_shards, shard_id)
    data2 = ds.GTZANDataset(DATA_DIR, usage="all", sampler=sampler)
    label_list2 = collect_labels(data2)
    np.testing.assert_array_equal(label_list1, label_list2)
    assert len(label_list2) == 1
def test_gtzan_exception():
    """
    Feature: GTZANDataset
    Description: test error cases for GTZANDataset
    Expectation: the results are as expected
    """
    logger.info("Test error cases for GTZANDataset")

    # A sampler cannot be combined with shuffle or with sharding args.
    with pytest.raises(RuntimeError, match="sampler and shuffle cannot be specified at the same time"):
        ds.GTZANDataset(DATA_DIR, shuffle=False, sampler=ds.PKSampler(3))
    with pytest.raises(RuntimeError, match="sampler and sharding cannot be specified at the same time"):
        ds.GTZANDataset(DATA_DIR, sampler=ds.PKSampler(3),
                        num_shards=2, shard_id=0)

    # num_shards and shard_id must be supplied together.
    with pytest.raises(RuntimeError, match="num_shards is specified and currently requires shard_id as well"):
        ds.GTZANDataset(DATA_DIR, num_shards=10)
    with pytest.raises(RuntimeError, match="shard_id is specified but num_shards is not"):
        ds.GTZANDataset(DATA_DIR, shard_id=0)

    # shard_id must lie inside [0, num_shards).
    interval_msg = "Input shard_id is not within the required interval"
    for shards, shard in ((5, -1), (5, 5), (2, 5)):
        with pytest.raises(ValueError, match=interval_msg):
            ds.GTZANDataset(DATA_DIR, num_shards=shards, shard_id=shard)

    # num_parallel_workers must be within the valid worker range.
    workers_msg = "num_parallel_workers exceeds"
    for workers in (0, 256, -2):
        with pytest.raises(ValueError, match=workers_msg):
            ds.GTZANDataset(DATA_DIR, shuffle=False, num_parallel_workers=workers)

    # shard_id must be an integer, not a string.
    with pytest.raises(TypeError, match="Argument shard_id"):
        ds.GTZANDataset(DATA_DIR, num_shards=2, shard_id="0")

    def exception_func(item):
        raise Exception("Error occur!")

    # A map operation that raises should surface as a RuntimeError.
    with pytest.raises(RuntimeError, match="The corresponding data files"):
        data = ds.GTZANDataset(DATA_DIR)
        data = data.map(operations=exception_func, input_columns=["waveform"], num_parallel_workers=1)
        for _ in data.create_dict_iterator(output_numpy=True, num_epochs=1):
            pass
def test_gtzan_sequential_sampler():
    """
    Feature: GTZANDataset
    Description: test GTZANDataset with SequentialSampler
    Expectation: the results are as expected
    """
    logger.info("Test GTZANDataset Op with SequentialSampler")
    num_samples = 2
    # A SequentialSampler must produce the same order as shuffle=False.
    data1 = ds.GTZANDataset(DATA_DIR, sampler=ds.SequentialSampler(num_samples=num_samples))
    data2 = ds.GTZANDataset(DATA_DIR, shuffle=False, num_samples=num_samples)
    labels_sampler, labels_no_shuffle = [], []
    row_count = 0
    iter1 = data1.create_dict_iterator(output_numpy=True, num_epochs=1)
    iter2 = data2.create_dict_iterator(output_numpy=True, num_epochs=1)
    for row1, row2 in zip(iter1, iter2):
        labels_sampler.append(row1["label"])
        labels_no_shuffle.append(row2["label"])
        row_count += 1
    np.testing.assert_array_equal(labels_sampler, labels_no_shuffle)
    assert row_count == num_samples
def test_gtzan_usage():
    """
    Feature: GTZANDataset
    Description: test GTZANDataset usage
    Expectation: the results are as expected
    """
    logger.info("Test GTZANDataset usage")

    def test_config(usage, gtzan_path=None):
        # Return the row count for a usage value, or the error text on failure.
        path = gtzan_path if gtzan_path is not None else DATA_DIR
        try:
            dataset = ds.GTZANDataset(path, usage=usage, shuffle=False)
            row_count = 0
            for _ in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
                row_count += 1
        except (ValueError, TypeError, RuntimeError) as e:
            return str(e)
        return row_count

    assert test_config("valid") == 3
    assert test_config("all") == 3
    assert "usage is not within the valid set of ['train', 'valid', 'test', 'all']" in test_config("invalid")
    assert "Argument usage with value ['list'] is not of type [<class 'str'>]" in test_config(["list"])

    # change this directory to the folder that contains all gtzan files.
    all_files_path = None
    # the following tests on the entire datasets.
    if all_files_path is not None:
        assert test_config("train", all_files_path) == 3
        assert test_config("valid", all_files_path) == 3
        assert ds.GTZANDataset(all_files_path, usage="train").get_dataset_size() == 3
        assert ds.GTZANDataset(all_files_path, usage="valid").get_dataset_size() == 3
if __name__ == '__main__':
    # Run every GTZAN test case in order when executed as a script.
    for test_case in (test_gtzan_basic,
                      test_gtzan_distribute_sampler,
                      test_gtzan_exception,
                      test_gtzan_sequential_sampler,
                      test_gtzan_usage):
        test_case()

Loading…
Cancel
Save