Browse Source

[fix][assistant][I3CEGF] fix can not compute bool tensor

r1.7
chenx2ovo 4 years ago
parent
commit
c50a4611ec
2 changed files with 16 additions and 1 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc
  2. +15
    -0
      tests/ut/cpp/dataset/execute_test.cc

+ 1
- 1
mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc View File

@@ -851,7 +851,7 @@ Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu

Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t fade_in_len,
int32_t fade_out_len, FadeShape fade_shape) {
if (DataType::DE_INT8 <= input->type().value() && input->type().value() <= DataType::DE_FLOAT32) {
if (DataType::DE_BOOL <= input->type().value() && input->type().value() <= DataType::DE_FLOAT32) {
std::shared_ptr<Tensor> waveform;
RETURN_IF_NOT_OK(TypeCast(input, &waveform, DataType(DataType::DE_FLOAT32)));
RETURN_IF_NOT_OK(Fade<float>(waveform, output, fade_in_len, fade_out_len, fade_shape));


+ 15
- 0
tests/ut/cpp/dataset/execute_test.cc View File

@@ -1420,6 +1420,21 @@ TEST_F(MindDataTestExecute, TestFadeWithInvalidArg) {
EXPECT_FALSE(s04.IsOk());
}

/// Feature: Fade
/// Description: test Fade with bool type
/// Expectation: success.
TEST_F(MindDataTestExecute, TestFadeWithBool) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestFadeWithBool.";
std::vector<bool> waveform = {1, 0, 1, 1, 1, 1, 1, 1};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(waveform, TensorShape({1, 8}), &input));
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> fade1 = std::make_shared<audio::Fade>(5, 6, FadeShape::kLinear);
mindspore::dataset::Execute Transform01({fade1});
Status s01 = Transform01(input_01, &input_01);
EXPECT_TRUE(s01.IsOk());
}

TEST_F(MindDataTestExecute, TestVolDefalutValue) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestVolDefalutValue.";
std::shared_ptr<Tensor> input_tensor_;


Loading…
Cancel
Save