Browse Source

!18598 [assistant][MulLawDecoding]

Merge pull request !18598 from QingfengLi/mulawdecoding
tags/v1.5.0-rc1
i-robot Gitee 4 years ago
parent
commit
b14297bd06
17 changed files with 508 additions and 14 deletions
  1. +13
    -0
      mindspore/ccsrc/minddata/dataset/api/audio.cc
  2. +12
    -0
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc
  3. +1
    -0
      mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt
  4. +52
    -0
      mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.cc
  5. +54
    -0
      mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h
  6. +1
    -0
      mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt
  7. +43
    -0
      mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc
  8. +7
    -0
      mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h
  9. +54
    -0
      mindspore/ccsrc/minddata/dataset/audio/kernels/mu_law_decoding_op.cc
  10. +47
    -0
      mindspore/ccsrc/minddata/dataset/audio/kernels/mu_law_decoding_op.h
  11. +21
    -0
      mindspore/ccsrc/minddata/dataset/include/dataset/audio.h
  12. +1
    -0
      mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
  13. +24
    -1
      mindspore/dataset/audio/transforms.py
  14. +12
    -0
      mindspore/dataset/audio/validators.py
  15. +60
    -2
      tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc
  16. +24
    -11
      tests/ut/cpp/dataset/execute_test.cc
  17. +82
    -0
      tests/ut/python/dataset/test_mu_law_decoding_op.py

+ 13
- 0
mindspore/ccsrc/minddata/dataset/api/audio.cc View File

@@ -29,6 +29,7 @@
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
@@ -226,6 +227,18 @@ std::shared_ptr<TensorOperation> LowpassBiquad::Parse() {
return std::make_shared<LowpassBiquadOperation>(data_->sample_rate_, data_->cutoff_freq_, data_->Q_);
}
// MuLawDecoding Transform Operation.
struct MuLawDecoding::Data {
explicit Data(int quantization_channels) : quantization_channels_(quantization_channels) {}
int quantization_channels_;
};
MuLawDecoding::MuLawDecoding(int quantization_channels) : data_(std::make_shared<Data>(quantization_channels)) {}
std::shared_ptr<TensorOperation> MuLawDecoding::Parse() {
return std::make_shared<MuLawDecodingOperation>(data_->quantization_channels_);
}
// TimeMasking Transform Operation.
struct TimeMasking::Data {
Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value)


+ 12
- 0
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc View File

@@ -33,6 +33,7 @@
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
@@ -191,6 +192,17 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(
MuLawDecodingOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::MuLawDecodingOperation, TensorOperation, std::shared_ptr<audio::MuLawDecodingOperation>>(
*m, "MuLawDecodingOperation")
.def(py::init([](int quantization_channels) {
auto mu_law_decoding = std::make_shared<audio::MuLawDecodingOperation>(quantization_channels);
THROW_IF_ERROR(mu_law_decoding->ValidateParams());
return mu_law_decoding;
}));
}));
PYBIND_REGISTER(
TimeMaskingOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::TimeMaskingOperation, TensorOperation, std::shared_ptr<audio::TimeMaskingOperation>>(


+ 1
- 0
mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt View File

@@ -15,6 +15,7 @@ add_library(audio-ir-kernels OBJECT
frequency_masking_ir.cc
highpass_biquad_ir.cc
lowpass_biquad_ir.cc
mu_law_decoding_ir.cc
time_masking_ir.cc
time_stretch_ir.cc
)


+ 52
- 0
mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.cc View File

@@ -0,0 +1,52 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"

#include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/audio/kernels/mu_law_decoding_op.h"

namespace mindspore {
namespace dataset {

namespace audio {

MuLawDecodingOperation::MuLawDecodingOperation(int quantization_channels)
: quantization_channels_(quantization_channels) {}

MuLawDecodingOperation::~MuLawDecodingOperation() = default;

Status MuLawDecodingOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateIntScalarPositive("MuLawEncoding", "quantization_channels", quantization_channels_));
return Status::OK();
}

Status MuLawDecodingOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["quantization_channels"] = quantization_channels_;
*out_json = args;
return Status::OK();
}

std::shared_ptr<TensorOp> MuLawDecodingOperation::Build() {
std::shared_ptr<MuLawDecodingOp> tensor_op = std::make_shared<MuLawDecodingOp>(quantization_channels_);
return tensor_op;
}

std::string MuLawDecodingOperation::Name() const { return kMuLawDecodingOperation; }

} // namespace audio
} // namespace dataset
} // namespace mindspore

+ 54
- 0
mindspore/ccsrc/minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h View File

@@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_DECODING_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_DECODING_IR_H_

#include <memory>
#include <string>
#include <vector>

#include "include/api/status.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"

namespace mindspore {
namespace dataset {
namespace audio {

constexpr char kMuLawDecodingOperation[] = "MuLawDecoding";

class MuLawDecodingOperation : public TensorOperation {
public:
explicit MuLawDecodingOperation(int quantization_channels);

~MuLawDecodingOperation();

std::shared_ptr<TensorOp> Build() override;

Status ValidateParams() override;

std::string Name() const override;

Status to_json(nlohmann::json *out_json) override;

private:
int quantization_channels_;
}; // class MuLawDecodingOperation

} // namespace audio
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_DECODING_IR_H_

+ 1
- 0
mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt View File

@@ -16,6 +16,7 @@ add_library(audio-kernels OBJECT
frequency_masking_op.cc
highpass_biquad_op.cc
lowpass_biquad_op.cc
mu_law_decoding_op.cc
time_masking_op.cc
time_stretch_op.cc
)

+ 43
- 0
mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc View File

@@ -466,5 +466,48 @@ Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
RETURN_STATUS_UNEXPECTED("ComplexNorm: " + std::string(e.what()));
}
}

template <typename T>
float sgn(T val) {
return (static_cast<T>(0) < val) - (val < static_cast<T>(0));
}

template <typename T>
Status Decoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, T mu) {
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
auto itr_out = (*output)->begin<T>();
auto itr = input->begin<T>();
auto end = input->end<T>();

while (itr != end) {
auto x_mu = *itr;
x_mu = ((x_mu) / mu) * 2 - 1.0;
x_mu = sgn(x_mu) * expm1(fabs(x_mu) * log1p(mu)) / mu;
*itr_out = x_mu;
++itr_out;
++itr;
}
return Status::OK();
}

Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int quantization_channels) {
if (input->type().value() >= DataType::DE_INT8 && input->type().value() <= DataType::DE_FLOAT32) {
float f_mu = static_cast<float>(quantization_channels) - 1;

// convert the data type to float
std::shared_ptr<Tensor> input_tensor;
RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));

RETURN_IF_NOT_OK(Decoding<float>(input_tensor, output, f_mu));
} else if (input->type().value() == DataType::DE_FLOAT64) {
double f_mu = static_cast<double>(quantization_channels) - 1;

RETURN_IF_NOT_OK(Decoding<double>(input, output, f_mu));
} else {
RETURN_STATUS_UNEXPECTED("MuLawDecoding: input tensor type should be int, float or double, but got: " +
input->type().ToString());
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

+ 7
- 0
mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h View File

@@ -276,6 +276,13 @@ Status MaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
/// \return Status code.
Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float power);

/// \brief Decode mu-law encoded signal.
/// \param input Tensor of shape <..., time>.
/// \param output Tensor of shape <..., time>.
/// \param quantization_channels Number of channels.
/// \return Status code.
Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int quantization_channels);

} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_

+ 54
- 0
mindspore/ccsrc/minddata/dataset/audio/kernels/mu_law_decoding_op.cc View File

@@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/mu_law_decoding_op.h"

#include "minddata/dataset/audio/kernels/audio_utils.h"

namespace mindspore {
namespace dataset {

// constructor
MuLawDecodingOp::MuLawDecodingOp(int quantization_channels) : quantization_channels_(quantization_channels) {}

// main function
Status MuLawDecodingOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);

CHECK_FAIL_RETURN_UNEXPECTED(input->Rank() >= 1, "MuLawDecoding: input tensor is not in shape of <..., time>.");

if (input->type().value() >= DataType::DE_INT8 && input->type().value() <= DataType::DE_FLOAT64) {
return MuLawDecoding(input, output, quantization_channels_);
} else {
RETURN_STATUS_UNEXPECTED("MuLawDecoding: input tensor type should be int, float or double, but got: " +
input->type().ToString());
}
}

Status MuLawDecodingOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
outputs[0] = DataType(DataType::DE_FLOAT64);
} else if (inputs[0] >= DataType(DataType::DE_INT8) || inputs[0] <= DataType(DataType::DE_FLOAT32)) {
outputs[0] = DataType(DataType::DE_FLOAT32);
} else {
RETURN_STATUS_UNEXPECTED("MuLawDecoding: input tensor type should be int, float or double, but got: " +
inputs[0].ToString());
}
return Status::OK();
}

} // namespace dataset
} // namespace mindspore

+ 47
- 0
mindspore/ccsrc/minddata/dataset/audio/kernels/mu_law_decoding_op.h View File

@@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MU_LAW_DECODING_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MU_LAW_DECODING_OP_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {

class MuLawDecodingOp : public TensorOp {
public:
explicit MuLawDecodingOp(int quantization_channels = 256);

~MuLawDecodingOp() override = default;

Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

std::string Name() const override { return kMuLawDecodingOp; }

private:
int quantization_channels_;
};
} // namespace dataset
} // namespace mindspore

#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MU_LAW_DECODING_OP_H_

+ 21
- 0
mindspore/ccsrc/minddata/dataset/include/dataset/audio.h View File

@@ -320,6 +320,27 @@ class LowpassBiquad final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief MuLawDecoding TensorTransform.
/// \note Decode mu-law encoded signal.
class MuLawDecoding final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] quantization_channels Number of channels, which must be positive (Default: 256).
explicit MuLawDecoding(int quantization_channels = 256);
/// \brief Destructor.
~MuLawDecoding() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief TimeMasking TensorTransform.
/// \notes Apply masking to a spectrogram in the time domain.
class TimeMasking final : public TensorTransform {


+ 1
- 0
mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h View File

@@ -152,6 +152,7 @@ constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp";
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
constexpr char kTimeStretchOp[] = "TimeStretchOp";



+ 24
- 1
mindspore/dataset/audio/transforms.py View File

@@ -26,7 +26,7 @@ from ..transforms.c_transforms import TensorOperation
from .utils import ScaleType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_deemph_biquad, \
check_highpass_biquad, check_lowpass_biquad, check_masking, check_time_stretch
check_highpass_biquad, check_lowpass_biquad, check_masking, check_mu_law_decoding, check_time_stretch
class AudioTensorOperation(TensorOperation):
@@ -406,6 +406,29 @@ class LowpassBiquad(AudioTensorOperation):
return cde.LowpassBiquadOperation(self.sample_rate, self.cutoff_freq, self.Q)
class MuLawDecoding(AudioTensorOperation):
"""
Decode mu-law encoded signal.
Args:
quantization_channels (int): Number of channels, which must be positive (Default: 256).
Examples:
>>> import numpy as np
>>>
>>> waveform = np.random.random([1, 3, 4])
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
>>> transforms = [audio.MuLawDecoding()]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_mu_law_decoding
def __init__(self, quantization_channels=256):
self.quantization_channels = quantization_channels
def parse(self):
return cde.MuLawDecodingOperation(self.quantization_channels)
class TimeMasking(AudioTensorOperation):
"""
Apply masking to a spectrogram in the time domain.


+ 12
- 0
mindspore/dataset/audio/validators.py View File

@@ -229,6 +229,18 @@ def check_lowpass_biquad(method):
return new_method


def check_mu_law_decoding(method):
"""Wrapper method to check the parameters of MuLawDecoding"""

@wraps(method)
def new_method(self, *args, **kwargs):
[quantization_channels], _ = parse_user_args(method, *args, **kwargs)
check_pos_int32(quantization_channels, "quantization_channels")
return method(self, *args, **kwargs)

return new_method


def check_time_stretch(method):
"""Wrapper method to check the parameters of TimeStretch."""



+ 60
- 2
tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc View File

@@ -825,7 +825,7 @@ TEST_F(MindDataTestPipeline, TestHighpassBiquadWrongArgs) {

// Check sample_rate
MS_LOG(INFO) << "sample_rate is zero.";
auto highpass_biquad_op_01 = audio::HighpassBiquad(0,200.0,0.7);
auto highpass_biquad_op_01 = audio::HighpassBiquad(0, 200.0, 0.7);
ds01 = ds->Map({highpass_biquad_op_01});
EXPECT_NE(ds01, nullptr);

@@ -834,10 +834,68 @@ TEST_F(MindDataTestPipeline, TestHighpassBiquadWrongArgs) {

// Check Q
MS_LOG(INFO) << "Q is zero.";
auto highpass_biquad_op_02 = audio::HighpassBiquad(44100,2000.0,0);
auto highpass_biquad_op_02 = audio::HighpassBiquad(44100, 2000.0, 0);
ds02 = ds->Map({highpass_biquad_op_02});
EXPECT_NE(ds02, nullptr);

std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
EXPECT_EQ(iter02, nullptr);
}

TEST_F(MindDataTestPipeline, TestMuLawDecodingBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMuLawDecodingBasic.";

// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeInt64, {1, 100}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);

ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);

auto MuLawDecodingOp = audio::MuLawDecoding();

ds = ds->Map({MuLawDecodingOp});
EXPECT_NE(ds, nullptr);

// Filtered waveform by MuLawDecoding
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);

std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));

std::vector<int64_t> expected = {1, 100};

int i = 0;
while (row.size() != 0) {
auto col = row["inputData"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);

iter->Stop();
}

TEST_F(MindDataTestPipeline, TestMuLawDecodingWrongArgs) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMuLawDecodingWrongArgs.";

// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeInt64, {1, 100}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);

ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);

auto MuLawDecodingOp = audio::MuLawDecoding(-10);

ds = ds->Map({MuLawDecodingOp});
std::shared_ptr<Iterator> iter1 = ds->CreateIterator();
EXPECT_EQ(iter1, nullptr);
}

+ 24
- 11
tests/ut/cpp/dataset/execute_test.cc View File

@@ -773,9 +773,9 @@ TEST_F(MindDataTestExecute, TestHighpassBiquadEager) {
float Q = 0.707;
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482,
0.3007, 0.9054, 0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
Tensor::CreateFromVector(test_vector, TensorShape({5,3}), &test);
std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482, 0.3007, 0.9054,
0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
Tensor::CreateFromVector(test_vector, TensorShape({5, 3}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
std::shared_ptr<TensorTransform> highpass_biquad(new audio::HighpassBiquad({sample_rate, cutoff_freq, Q}));
auto transform = Execute({highpass_biquad});
@@ -787,11 +787,10 @@ TEST_F(MindDataTestExecute, TestHighpassBiquadParamCheckQ) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestHighpassBiquadParamCheckQ.";
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<float> test_vector = {0.6013, 0.8081, 0.6600, 0.4278, 0.4049, 0.0541, 0.8800, 0.7143, 0.0926,
0.3502, 0.6148, 0.8738, 0.1869, 0.9023, 0.4293, 0.2175, 0.5132, 0.2622,
0.6490, 0.0741, 0.7903, 0.3428, 0.1598, 0.4841, 0.8128, 0.7409, 0.7226,
0.4951, 0.5589, 0.9210};
Tensor::CreateFromVector(test_vector, TensorShape({5,3,2}), &test);
std::vector<float> test_vector = {0.6013, 0.8081, 0.6600, 0.4278, 0.4049, 0.0541, 0.8800, 0.7143, 0.0926, 0.3502,
0.6148, 0.8738, 0.1869, 0.9023, 0.4293, 0.2175, 0.5132, 0.2622, 0.6490, 0.0741,
0.7903, 0.3428, 0.1598, 0.4841, 0.8128, 0.7409, 0.7226, 0.4951, 0.5589, 0.9210};
Tensor::CreateFromVector(test_vector, TensorShape({5, 3, 2}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check Q
std::shared_ptr<TensorTransform> highpass_biquad_op = std::make_shared<audio::HighpassBiquad>(44100, 3000.5, 0);
@@ -804,9 +803,8 @@ TEST_F(MindDataTestExecute, TestHighpassBiquadParamCheckSampleRate) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestHighpassBiquadParamCheckSampleRate.";
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.0237, 0.6026, 0.3801, 0.1978, 0.8672,
0.0095, 0.5166, 0.2641, 0.5485, 0.5144};
Tensor::CreateFromVector(test_vector, TensorShape({1,10}), &test);
std::vector<double> test_vector = {0.0237, 0.6026, 0.3801, 0.1978, 0.8672, 0.0095, 0.5166, 0.2641, 0.5485, 0.5144};
Tensor::CreateFromVector(test_vector, TensorShape({1, 10}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check sample_rate
std::shared_ptr<TensorTransform> highpass_biquad_op = std::make_shared<audio::HighpassBiquad>(0, 3000.5, 0.7);
@@ -814,3 +812,18 @@ TEST_F(MindDataTestExecute, TestHighpassBiquadParamCheckSampleRate) {
Status rc = transform({input}, &output);
ASSERT_FALSE(rc.IsOk());
}

TEST_F(MindDataTestExecute, TestMuLawDecodingEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestMuLawDecodingEager.";
// testing
std::shared_ptr<Tensor> input_tensor_;
Tensor::CreateFromVector(std::vector<float>({1, 254, 231, 155, 101, 77}), TensorShape({1, 6}), &input_tensor_);

auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
std::shared_ptr<TensorTransform> mu_law_encoding_01 = std::make_shared<audio::MuLawDecoding>(255);

// Filtered waveform by mulawencoding
mindspore::dataset::Execute Transform01({mu_law_encoding_01});
Status s01 = Transform01(input_02, &input_02);
EXPECT_TRUE(s01.IsOk());
}

+ 82
- 0
tests/ut/python/dataset/test_mu_law_decoding_op.py View File

@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing MuLawDecoding op in DE.
"""

import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger


def test_mu_law_decoding():
"""
Test mu_law_decoding_op (pipeline).
"""
logger.info("Test MuLawDecoding.")

def gen():
data = np.array([[10, 100, 70, 200]])
yield (np.array(data, dtype=np.float32),)

dataset = ds.GeneratorDataset(source=gen, column_names=["multi_dim_data"])

dataset = dataset.map(operations=audio.MuLawDecoding(), input_columns=["multi_dim_data"])

for i in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
assert i["multi_dim_data"].shape == (1, 4)
expected = np.array([[-0.6459359526634216, -0.009046762250363827, -0.04388953000307083, 0.08788024634122849]])
assert np.array_equal(i["multi_dim_data"], expected)

logger.info("Finish testing MuLawDecoding.")


def test_mu_law_decoding_eager():
"""
Test mu_law_decoding_op callable (eager).
"""
logger.info("Test MuLawDecoding callable.")

input_t = np.array([70, 170])
output_t = audio.MuLawDecoding()(input_t)
assert output_t.shape == (2,)
excepted = np.array([-0.04388953000307083, 0.02097884565591812])
assert np.array_equal(output_t, excepted)

logger.info("Finish testing MuLawDecoding.")


def test_mu_law_decoding_uncallable():
"""
Test mu_law_decoding_op not callable.
"""
logger.info("Test MuLawDecoding not callable.")

try:
input_t = np.random.rand(2, 4)
output_t = audio.MuLawDecoding(-3)(input_t)
assert output_t.shape == (2, 4)
except ValueError as e:
assert 'Input quantization_channels is not within the required interval of [1, 2147483647].' in str(e)

logger.info("Finish testing MuLawDecoding.")


if __name__ == "__main__":
test_mu_law_decoding()
test_mu_law_decoding_eager()
test_mu_law_decoding_uncallable()

Loading…
Cancel
Save