| @@ -25,6 +25,7 @@ | |||
| #include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/contrast_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" | |||
| @@ -166,6 +167,21 @@ std::shared_ptr<TensorOperation> Contrast::Parse() { | |||
| return std::make_shared<ContrastOperation>(data_->enhancement_amount_); | |||
| } | |||
| // DCShift Transform Operation. | |||
| struct DCShift::Data { | |||
| Data(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {} | |||
| float limiter_gain_; | |||
| float shift_; | |||
| }; | |||
| DCShift::DCShift(float shift) : data_(std::make_shared<Data>(shift, shift)) {} | |||
| DCShift::DCShift(float shift, float limiter_gain) : data_(std::make_shared<Data>(shift, limiter_gain)) {} | |||
| std::shared_ptr<TensorOperation> DCShift::Parse() { | |||
| return std::make_shared<DCShiftOperation>(data_->shift_, data_->limiter_gain_); | |||
| } | |||
| // DeemphBiquad Transform Operation. | |||
| struct DeemphBiquad::Data { | |||
| explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {} | |||
| @@ -29,6 +29,7 @@ | |||
| #include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/contrast_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" | |||
| #include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" | |||
| @@ -148,6 +149,16 @@ PYBIND_REGISTER(ContrastOperation, 1, ([](const py::module *m) { | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER(DCShiftOperation, 1, ([](const py::module *m) { | |||
| (void)py::class_<audio::DCShiftOperation, TensorOperation, std::shared_ptr<audio::DCShiftOperation>>( | |||
| *m, "DCShiftOperation") | |||
| .def(py::init([](float shift, float limiter_gain) { | |||
| auto dc_shift = std::make_shared<audio::DCShiftOperation>(shift, limiter_gain); | |||
| THROW_IF_ERROR(dc_shift->ValidateParams()); | |||
| return dc_shift; | |||
| })); | |||
| })); | |||
| PYBIND_REGISTER( | |||
| DeemphBiquadOperation, 1, ([](const py::module *m) { | |||
| (void)py::class_<audio::DeemphBiquadOperation, TensorOperation, std::shared_ptr<audio::DeemphBiquadOperation>>( | |||
| @@ -11,6 +11,7 @@ add_library(audio-ir-kernels OBJECT | |||
| bass_biquad_ir.cc | |||
| complex_norm_ir.cc | |||
| contrast_ir.cc | |||
| dc_shift_ir.cc | |||
| deemph_biquad_ir.cc | |||
| equalizer_biquad_ir.cc | |||
| frequency_masking_ir.cc | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" | |||
| #include "minddata/dataset/audio/ir/validators.h" | |||
| #include "minddata/dataset/audio/kernels/dc_shift_op.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace audio { | |||
| // DCShiftOperation | |||
| DCShiftOperation::DCShiftOperation(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {} | |||
| Status DCShiftOperation::ValidateParams() { | |||
| RETURN_IF_NOT_OK(ValidateScalar("DCShift", "shift", shift_, {-2.0, 2.0}, false, false)); | |||
| return Status::OK(); | |||
| } | |||
| std::shared_ptr<TensorOp> DCShiftOperation::Build() { | |||
| std::shared_ptr<DCShiftOp> tensor_op = std::make_shared<DCShiftOp>(shift_, limiter_gain_); | |||
| return tensor_op; | |||
| } | |||
| Status DCShiftOperation::to_json(nlohmann::json *out_json) { | |||
| nlohmann::json args; | |||
| args["shift"] = shift_; | |||
| args["limiter_gain"] = limiter_gain_; | |||
| *out_json = args; | |||
| return Status::OK(); | |||
| } | |||
| } // namespace audio | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DC_SHIFT_IR_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DC_SHIFT_IR_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "include/api/status.h" | |||
| #include "minddata/dataset/include/dataset/constants.h" | |||
| #include "minddata/dataset/include/dataset/transforms.h" | |||
| #include "minddata/dataset/kernels/ir/tensor_operation.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| namespace audio { | |||
| constexpr char kDCShiftOperation[] = "DCShift"; | |||
| class DCShiftOperation : public TensorOperation { | |||
| public: | |||
| DCShiftOperation(float shift, float limiter_gain); | |||
| ~DCShiftOperation() = default; | |||
| std::shared_ptr<TensorOp> Build() override; | |||
| Status ValidateParams() override; | |||
| std::string Name() const override { return kDCShiftOperation; } | |||
| Status to_json(nlohmann::json *out_json) override; | |||
| private: | |||
| float shift_; | |||
| float limiter_gain_; | |||
| }; | |||
| } // namespace audio | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DC_SHIFT_IR_H_ | |||
| @@ -12,6 +12,7 @@ add_library(audio-kernels OBJECT | |||
| bass_biquad_op.cc | |||
| complex_norm_op.cc | |||
| contrast_op.cc | |||
| dc_shift_op.cc | |||
| deemph_biquad_op.cc | |||
| equalizer_biquad_op.cc | |||
| frequency_masking_op.cc | |||
| @@ -150,6 +150,40 @@ Status Contrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o | |||
| return Status::OK(); | |||
| } | |||
| /// \brief Apply a DC shift to the audio. | |||
| /// \param input/output: Tensor of shape <...,time>. | |||
| /// \param shift: the amount to shift the audio. | |||
| /// \param limiter_gain: used only on peaks to prevent clipping. | |||
| /// \return Status code. | |||
| template <typename T> | |||
| Status DCShift(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float shift, float limiter_gain) { | |||
| float limiter_threshold = 0.0; | |||
| if (shift != limiter_gain && shift != 0) { | |||
| limiter_threshold = 1.0 - (std::abs(shift) - limiter_gain); | |||
| for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) { | |||
| if (*itr > limiter_threshold && shift > 0) { | |||
| T peak = (*itr - limiter_threshold) * limiter_gain / (1 - limiter_threshold); | |||
| T sample = (peak + limiter_threshold + shift); | |||
| *itr = sample > limiter_threshold ? limiter_threshold : sample; | |||
| } else if (*itr < -limiter_threshold && shift < 0) { | |||
| T peak = (*itr + limiter_threshold) * limiter_gain / (1 - limiter_threshold); | |||
| T sample = (peak + limiter_threshold + shift); | |||
| *itr = sample < -limiter_threshold ? -limiter_threshold : sample; | |||
| } else { | |||
| T sample = (*itr + shift); | |||
| *itr = (sample > 1 || sample < -1) ? (sample > 1 ? 1 : -1) : sample; | |||
| } | |||
| } | |||
| } else { | |||
| for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) { | |||
| T sample = (*itr + shift); | |||
| *itr = sample > 1 || sample < -1 ? (sample > 1 ? 1 : -1) : sample; | |||
| } | |||
| } | |||
| *output = input; | |||
| return Status::OK(); | |||
| } | |||
| /// \brief Perform an IIR filter by evaluating difference equation. | |||
| /// \param input/output: Tensor of shape <..., time> | |||
| /// \param a_coeffs: denominator coefficients of difference equation of dimension of (n_order + 1). | |||
| @@ -0,0 +1,55 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "minddata/dataset/audio/kernels/dc_shift_op.h" | |||
| #include "minddata/dataset/audio/kernels/audio_utils.h" | |||
| #include "minddata/dataset/kernels/data/data_utils.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| Status DCShiftOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) { | |||
| IO_CHECK(input, output); | |||
| // input <..., time>. | |||
| CHECK_FAIL_RETURN_UNEXPECTED(input->Rank() > 0, "ComplexNorm: input tensor is not in shape of <..., time>."); | |||
| // If datatype is not a numeric type, then we cannot deal with the data. | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| input->type().IsNumeric(), | |||
| "DCShift: input tensor type should be int, float or double, but got: " + input->type().ToString()); | |||
| if (input->type() == DataType(DataType::DE_FLOAT64)) { | |||
| return DCShift<double>(input, output, shift_, limiter_gain_); | |||
| } else { | |||
| std::shared_ptr<Tensor> tmp; | |||
| TypeCast(input, &tmp, DataType(DataType::DE_FLOAT32)); | |||
| return DCShift<float>(tmp, output, shift_, limiter_gain_); | |||
| } | |||
| } | |||
| Status DCShiftOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) { | |||
| RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); | |||
| CHECK_FAIL_RETURN_UNEXPECTED( | |||
| inputs[0].IsNumeric(), | |||
| "DCShift: input tensor type should be int, float or double, but got: " + inputs[0].ToString()); | |||
| if (inputs[0] == DataType(DataType::DE_FLOAT64)) { | |||
| outputs[0] = DataType(DataType::DE_FLOAT64); | |||
| } else { | |||
| outputs[0] = DataType(DataType::DE_FLOAT32); | |||
| } | |||
| return Status::OK(); | |||
| } | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DC_SHIFT_OP_H_ | |||
| #define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DC_SHIFT_OP_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "minddata/dataset/core/tensor.h" | |||
| #include "minddata/dataset/kernels/tensor_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| class DCShiftOp : public TensorOp { | |||
| public: | |||
| DCShiftOp(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {} | |||
| ~DCShiftOp() override = default; | |||
| void Print(std::ostream &out) const override { | |||
| out << Name() << ":: shift: " << shift_ << ", limiter_gain: " << limiter_gain_; | |||
| } | |||
| Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override; | |||
| std::string Name() const override { return kDCShiftOp; } | |||
| Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override; | |||
| protected: | |||
| float shift_; | |||
| float limiter_gain_; | |||
| }; | |||
| } // namespace dataset | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DC_SHIFT_OP_H_ | |||
| @@ -230,6 +230,32 @@ class Contrast final : public TensorTransform { | |||
| std::shared_ptr<Data> data_; | |||
| }; | |||
| /// \brief Apply a DC shift to the audio. | |||
| class DCShift : public TensorTransform { | |||
| public: | |||
| /// \brief Constructor | |||
| /// \param[in] shift Indicates the amount to shift the audio, the value must be in the range [-2.0, 2.0]. | |||
| /// \param[in] limiter_gain Used only on peaks to prevent clipping. | |||
| DCShift(float shift, float limiter_gain); | |||
| /// \brief Constructor | |||
| /// \param[in] shift Indicates the amount to shift the audio. | |||
| /// \note This constructor will use `shift` as `limiter_gain`. | |||
| explicit DCShift(float shift); | |||
| /// \brief Destructor. | |||
| ~DCShift() = default; | |||
| protected: | |||
| /// \brief Function to convert TensorTransform object into a TensorOperation object. | |||
| /// \return Shared pointer to TensorOperation object. | |||
| std::shared_ptr<TensorOperation> Parse() override; | |||
| private: | |||
| struct Data; | |||
| std::shared_ptr<Data> data_; | |||
| }; | |||
| /// \brief Design two-pole deemph filter. Similar to SoX implementation. | |||
| class DeemphBiquad final : public TensorTransform { | |||
| public: | |||
| @@ -148,6 +148,7 @@ constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp"; | |||
| constexpr char kBassBiquadOp[] = "BassBiquadOp"; | |||
| constexpr char kComplexNormOp[] = "ComplexNormOp"; | |||
| constexpr char kContrastOp[] = "ContrastOp"; | |||
| constexpr char kDCShiftOp[] = "DCShiftOp"; | |||
| constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp"; | |||
| constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp"; | |||
| constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp"; | |||
| @@ -25,9 +25,9 @@ import mindspore._c_dataengine as cde | |||
| from ..transforms.c_transforms import TensorOperation | |||
| from .utils import ScaleType | |||
| from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \ | |||
| check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_deemph_biquad, \ | |||
| check_equalizer_biquad, check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_masking, \ | |||
| check_mu_law_decoding, check_time_stretch | |||
| check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_dc_shift, \ | |||
| check_deemph_biquad, check_equalizer_biquad, check_highpass_biquad, check_lfilter, check_lowpass_biquad, \ | |||
| check_masking, check_mu_law_decoding, check_time_stretch | |||
| class AudioTensorOperation(TensorOperation): | |||
| @@ -295,6 +295,33 @@ class Contrast(AudioTensorOperation): | |||
| return cde.ContrastOperation(self.enhancement_amount) | |||
| class DCShift(AudioTensorOperation): | |||
| """ | |||
| Apply a DC shift to the audio. | |||
| Args: | |||
| shift (float): The amount to shift the audio, the value must be in the range [-2.0, 2.0]. | |||
| limiter_gain (float, optional): Used only on peaks to prevent clipping, | |||
| the value should be much less than 1, such as 0.05 or 0.02. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> | |||
| >>> waveform = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93]) | |||
| >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"]) | |||
| >>> transforms = [audio.DCShift(0.5, 0.02)] | |||
| >>> numpy_slices_dataset = numpy_slices_dataset.map(operation=transforms, input_columns=["audio"]) | |||
| """ | |||
| @check_dc_shift | |||
| def __init__(self, shift, limiter_gain=None): | |||
| self.shift = shift | |||
| self.limiter_gain = limiter_gain if limiter_gain else shift | |||
| def parse(self): | |||
| return cde.DCShiftOperation(self.shift, self.limiter_gain) | |||
| class DeemphBiquad(AudioTensorOperation): | |||
| """ | |||
| Design two-pole deemph filter for audio waveform of dimension of (..., time). | |||
| @@ -201,6 +201,20 @@ def check_contrast(method): | |||
| return new_method | |||
| def check_dc_shift(method): | |||
| """Wrapper method to check the parameters of DCShift.""" | |||
| @wraps(method) | |||
| def new_method(self, *args, **kwargs): | |||
| [shift, limiter_gain], _ = parse_user_args(method, *args, **kwargs) | |||
| type_check(shift, (float, int), "shift") | |||
| check_value(shift, [-2.0, 2.0], "shift") | |||
| if limiter_gain is not None: | |||
| type_check(limiter_gain, (float, int), "limiter_gain") | |||
| return method(self, *args, **kwargs) | |||
| return new_method | |||
| def check_deemph_biquad(method): | |||
| """Wrapper method to check the parameters of CutMixBatch.""" | |||
| @@ -1016,3 +1016,54 @@ TEST_F(MindDataTestPipeline, TestLfilterWrongArgs) { | |||
| std::shared_ptr<Iterator> iter01 = ds01->CreateIterator(); | |||
| EXPECT_EQ(iter01, nullptr); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestDCShiftPipeline) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDCShiftPipeline."; | |||
| std::shared_ptr<SchemaObj> schema = Schema(); | |||
| ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {1, 2, 100})); | |||
| std::shared_ptr<Dataset> ds = RandomData(50, schema); | |||
| EXPECT_NE(ds, nullptr); | |||
| auto dc_shift_op = audio::DCShift(0.8, 0.02); | |||
| ds = ds->Map({dc_shift_op}); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::unordered_map<std::string, mindspore::MSTensor> row; | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| std::vector<int64_t> expected = {1, 2, 100}; | |||
| int i = 0; | |||
| while (row.size() != 0) { | |||
| auto col = row["waveform"]; | |||
| ASSERT_EQ(col.Shape(), expected); | |||
| ASSERT_EQ(col.Shape().size(), 3); | |||
| ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32); | |||
| ASSERT_OK(iter->GetNextRow(&row)); | |||
| i++; | |||
| } | |||
| EXPECT_EQ(i, 50); | |||
| iter->Stop(); | |||
| } | |||
| TEST_F(MindDataTestPipeline, TestDCShiftPipelineError) { | |||
| MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDCShiftPipelineError."; | |||
| std::shared_ptr<SchemaObj> schema = Schema(); | |||
| ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {100})); | |||
| std::shared_ptr<Dataset> ds = RandomData(4, schema); | |||
| EXPECT_NE(ds, nullptr); | |||
| auto dc_shift_op = audio::DCShift(3, 0.02); | |||
| ds = ds->Map({dc_shift_op}); | |||
| EXPECT_NE(ds, nullptr); | |||
| std::shared_ptr<Iterator> iter = ds->CreateIterator(); | |||
| EXPECT_EQ(iter, nullptr); | |||
| } | |||
| @@ -913,3 +913,17 @@ TEST_F(MindDataTestExecute, TestLFilterWithWrongArg) { | |||
| Status s01 = Transform01(input_02, &input_02); | |||
| EXPECT_FALSE(s01.IsOk()); | |||
| } | |||
| TEST_F(MindDataTestExecute, TestDCShiftEager) { | |||
| MS_LOG(INFO) << "Doing MindDataTestExecute-TestDCShiftEager."; | |||
| std::vector<float> origin = {0.67443, 1.87523, 0.73465, -0.74553, -1.54346, 1.54093, -1.23453}; | |||
| std::shared_ptr<Tensor> de_tensor; | |||
| Tensor::CreateFromVector(origin, &de_tensor); | |||
| std::shared_ptr<TensorTransform> dc_shift = std::make_shared<audio::DCShift>(0.5, 0.02); | |||
| auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor)); | |||
| mindspore::dataset::Execute Transform({dc_shift}); | |||
| Status s = Transform(input, &input); | |||
| ASSERT_TRUE(s.IsOk()); | |||
| } | |||
| @@ -0,0 +1,76 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.dataset as ds | |||
| import mindspore.dataset.audio.transforms as a_c_trans | |||
| def count_unequal_element(data_expected, data_me, rtol, atol): | |||
| assert data_expected.shape == data_me.shape | |||
| total_count = len(data_expected.flatten()) | |||
| error = np.abs(data_expected - data_me) | |||
| greater = np.greater(error, atol + np.abs(data_expected) * rtol) | |||
| loss_count = np.count_nonzero(greater) | |||
| assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( | |||
| data_expected[greater], data_me[greater], error[greater]) | |||
| def test_func_dc_shift_eager(): | |||
| """ | |||
| Eager Test | |||
| """ | |||
| arr = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93, 0.71, 0.61], dtype=np.double) | |||
| expected = np.array([0.0400, 0.0400, -0.0400, -0.2600, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400], | |||
| dtype=np.double) | |||
| dcshift_op = a_c_trans.DCShift(1.0, 0.04) | |||
| output = dcshift_op(arr) | |||
| count_unequal_element(expected, output, 0.0001, 0.0001) | |||
| def test_func_dc_shift_pipeline(): | |||
| """ | |||
| Pipeline Test | |||
| """ | |||
| arr = np.array([[1.14, -1.06, 0.94, 0.90], [-1.11, 1.40, -0.33, 1.43]], dtype=np.double) | |||
| expected = np.array([[0.2300, -0.2600, 0.2300, 0.2300], [-0.3100, 0.2300, 0.4700, 0.2300]], dtype=np.double) | |||
| dataset = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False) | |||
| dcshift_op = a_c_trans.DCShift(0.8, 0.03) | |||
| dataset = dataset.map(operations=dcshift_op, input_columns=["col1"]) | |||
| for item1, item2 in zip(dataset.create_dict_iterator(output_numpy=True), expected): | |||
| count_unequal_element(item2, item1['col1'], 0.0001, 0.0001) | |||
| def test_func_dc_shift_pipeline_error(): | |||
| """ | |||
| Pipeline Error Test | |||
| """ | |||
| arr = np.random.uniform(-2, 2, size=(1000)).astype(np.float) | |||
| label = np.random.sample((1000, 1)) | |||
| data = (arr, label) | |||
| dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False) | |||
| num_itr = 0 | |||
| with pytest.raises(ValueError, match=r"Input shift is not within the required interval of \[-2.0, 2.0\]."): | |||
| dcshift_op = a_c_trans.DCShift(2.5, 0.03) | |||
| dataset = dataset.map(operations=dcshift_op, input_columns=["col1"]) | |||
| for _ in dataset.create_dict_iterator(output_numpy=True): | |||
| num_itr += 1 | |||
| if __name__ == "__main__": | |||
| test_func_dc_shift_eager() | |||
| test_func_dc_shift_pipeline() | |||
| test_func_dc_shift_pipeline_error() | |||