diff --git a/mindspore/core/ir/dtype/number.cc b/mindspore/core/ir/dtype/number.cc index 8bbfcd7e14..c18f3e6f0e 100644 --- a/mindspore/core/ir/dtype/number.cc +++ b/mindspore/core/ir/dtype/number.cc @@ -47,6 +47,12 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) { } } +Complex::Complex(const int nbits) : Number(TypeId::kNumberTypeComplex64, nbits, false) { + if (nbits != 64) { + MS_LOG(EXCEPTION) << "Wrong number of bits."; + } +} + const TypePtr kBool = std::make_shared(); const TypePtr kInt8 = std::make_shared(8); const TypePtr kInt16 = std::make_shared(16); @@ -63,4 +69,5 @@ const TypePtr kInt = std::make_shared(); const TypePtr kUInt = std::make_shared(); const TypePtr kFloat = std::make_shared(); const TypePtr kNumber = std::make_shared(); +const TypePtr kComplex64 = std::make_shared(64); } // namespace mindspore diff --git a/mindspore/core/ir/dtype/number.h b/mindspore/core/ir/dtype/number.h index d753546b2e..1410eb368d 100644 --- a/mindspore/core/ir/dtype/number.h +++ b/mindspore/core/ir/dtype/number.h @@ -150,6 +150,28 @@ class Float : public Number { } }; +// Complex +class Complex : public Number { + public: + Complex() : Number(kNumberTypeComplex64, 0) {} + explicit Complex(const int nbits); + ~Complex() override {} + MS_DECLARE_PARENT(Complex, Number) + + TypeId generic_type_id() const override { return kNumberTypeComplex64; } + TypePtr DeepCopy() const override { + if (nbits() == 0) { + return std::make_shared(); + } + return std::make_shared(nbits()); + } + std::string ToString() const override { return GetTypeName("Complex64"); } + std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); } + std::string DumpText() const override { + return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits()); + } +}; + extern const TypePtr kBool; extern const TypePtr kInt8; extern const TypePtr kInt16; @@ -166,6 +188,7 @@ extern const TypePtr kInt; extern const TypePtr kUInt; extern const TypePtr kFloat; extern const TypePtr kNumber; +extern const TypePtr kComplex64; } // namespace mindspore #endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ diff --git a/mindspore/core/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc index 5144aff490..7c69807138 100644 --- a/mindspore/core/ir/dtype_extends.cc +++ b/mindspore/core/ir/dtype_extends.cc @@ -69,6 +69,8 @@ TypePtr TypeIdToType(TypeId id) { return kFloat32; case kNumberTypeFloat64: return kFloat64; + case kNumberTypeComplex64: + return kComplex64; case kNumberTypeInt8: return kInt8; case kNumberTypeInt16: diff --git a/mindspore/lite/src/ops/audio_spectrogram.cc b/mindspore/lite/src/ops/audio_spectrogram.cc new file mode 100644 index 0000000000..e5c565625d --- /dev/null +++ b/mindspore/lite/src/ops/audio_spectrogram.cc @@ -0,0 +1,107 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/audio_spectrogram.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value.AsAudioSpectrogram()->windowSize; } +int AudioSpectrogram::GetStride() const { return this->primitive_->value.AsAudioSpectrogram()->stride; } +bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value.AsAudioSpectrogram()->magSquare; } + +#else +int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value_as_AudioSpectrogram()->windowSize(); } +int AudioSpectrogram::GetStride() const { return this->primitive_->value_as_AudioSpectrogram()->stride(); } +bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value_as_AudioSpectrogram()->magSquare(); } +int AudioSpectrogram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_AudioSpectrogram(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Add return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateAudioSpectrogram(*fbb, attr->windowSize(), attr->stride(), attr->magSquare()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AudioSpectrogram, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *AudioSpectrogramCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry AudioSpectrogramRegistry(schema::PrimitiveType_AudioSpectrogram, AudioSpectrogramCreator); +#endif +int AudioSpectrogram::Log2Ceil(uint32_t length) { + if (length == 0) { + return -1; + } + int floor = 0; + for (int i = 4; i >= 0; --i) { + int shift = (1 << i); + uint32_t tmp = length >> shift; + if (tmp != 0) { + length = tmp; + floor += shift; + } + } + return length == (length & ~(length - 1)) ? floor : floor + 1; +} +uint32_t AudioSpectrogram::GetFftLength(uint32_t length) { + int shift = Log2Ceil(length); + return 1 << shift; +} +int AudioSpectrogram::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + auto input_shape = input->shape(); + if (input_shape.size() != 2) { + MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; + return RET_ERROR; + } + if (GetWindowSize() < 2) { + MS_LOG(ERROR) << "window size is too short, now is " << GetWindowSize(); + return RET_ERROR; + } + if (GetStride() < 1) { + MS_LOG(ERROR) << "stride must be positive, now is " << GetStride(); + return RET_ERROR; + } + std::vector output_shape(3); + output_shape[0] = input_shape[1]; + // output height + int sample_sub_window = input_shape[0] - GetWindowSize(); + output_shape[1] = sample_sub_window < 0 ? 0 : 1 + sample_sub_window / GetStride(); + // compute fft length + int fft_length = GetFftLength(GetWindowSize()); + output_shape[2] = fft_length / 2 + 1; + outputs_.front()->set_shape(output_shape); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/audio_spectrogram.h b/mindspore/lite/src/ops/audio_spectrogram.h new file mode 100644 index 0000000000..53e679a097 --- /dev/null +++ b/mindspore/lite/src/ops/audio_spectrogram.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ +#define LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class AudioSpectrogram : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC); + AudioSpectrogram() = default; + explicit AudioSpectrogram(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + void SetWindowSize(int window_size) { this->primitive_->value.AsAudioSpectrogram()->windowSize = window_size; } + void SetStride(int stride) { this->primitive_->value.AsAudioSpectrogram()->stride = stride; } + void SetMagSquare(bool mag_square) { this->primitive_->value.AsAudioSpectrogram()->magSquare = mag_square; } +#else + AudioSpectrogram() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int GetWindowSize() const; + int GetStride() const; + bool GetMagSquare() const; + int Log2Ceil(uint32_t length); + uint32_t GetFftLength(uint32_t length); + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ diff --git a/mindspore/lite/src/ops/fft_imag.cc b/mindspore/lite/src/ops/fft_imag.cc new file mode 100644 index 0000000000..3e2f6c07f9 --- /dev/null +++ b/mindspore/lite/src/ops/fft_imag.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/fft_imag.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +int FftImag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateEqual(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftImag, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *FftImagCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry FftImagRegistry(schema::PrimitiveType_FftImag, FftImagCreator); +#endif +int FftImag::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_data_type(TypeId::kNumberTypeFloat32); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + auto input_shape = input->shape(); + input_shape.pop_back(); + outputs_.front()->set_shape(input_shape); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/fft_imag.h b/mindspore/lite/src/ops/fft_imag.h new file mode 100644 index 0000000000..1b35e77a16 --- /dev/null +++ b/mindspore/lite/src/ops/fft_imag.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class FftImag : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(FftImag, PrimitiveC); + FftImag() = default; + explicit FftImag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#else + FftImag() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ diff --git a/mindspore/lite/src/ops/fft_real.cc b/mindspore/lite/src/ops/fft_real.cc new file mode 100644 index 0000000000..de68c73723 --- /dev/null +++ b/mindspore/lite/src/ops/fft_real.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/fft_real.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifndef PRIMITIVE_WRITEABLE +int FftReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateEqual(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftReal, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *FftRealCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry FftRealRegistry(schema::PrimitiveType_FftReal, FftRealCreator); +#endif +int FftReal::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_data_type(TypeId::kNumberTypeFloat32); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + auto input_shape = input->shape(); + input_shape.pop_back(); + outputs_.front()->set_shape(input_shape); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/fft_real.h b/mindspore/lite/src/ops/fft_real.h new file mode 100644 index 0000000000..b35e4af5b1 --- /dev/null +++ b/mindspore/lite/src/ops/fft_real.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class FftReal : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(FftReal, PrimitiveC); + FftReal() = default; + explicit FftReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +#else + FftReal() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ diff --git a/mindspore/lite/src/ops/mfcc.cc b/mindspore/lite/src/ops/mfcc.cc new file mode 100644 index 0000000000..f184e33b50 --- /dev/null +++ b/mindspore/lite/src/ops/mfcc.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/mfcc.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value.AsMfcc()->freqUpperLimit; } +float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value.AsMfcc()->freqLowerLimit; } +int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value.AsMfcc()->filterBankChannelNum; } +int Mfcc::GetDctCoeffNum() const { return this->primitive_->value.AsMfcc()->dctCoeffNum; } + +#else +float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value_as_Mfcc()->freqUpperLimit(); } +float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value_as_Mfcc()->freqLowerLimit(); } +int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value_as_Mfcc()->filterBankChannelNum(); } +int Mfcc::GetDctCoeffNum() const { return this->primitive_->value_as_Mfcc()->dctCoeffNum(); } +int Mfcc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Mfcc(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Add return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateMfcc(*fbb, attr->freqUpperLimit(), attr->freqLowerLimit(), + attr->filterBankChannelNum(), attr->dctCoeffNum()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mfcc, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *MfccCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry MfccRegistry(schema::PrimitiveType_Mfcc, MfccCreator); +#endif +int Mfcc::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_data_type(input->data_type()); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + auto input_shape = input->shape(); + if (input_shape.size() != 3) { + MS_LOG(ERROR) << "first input shape is error, which need to be 3 dimensions, but the dimension is " + << input_shape.size(); + return RET_ERROR; + } + if (inputs_[1]->ElementsNum() != 1) { + MS_LOG(ERROR) << "second input element num is error, which need only a value, but the number is " + << inputs_[1]->ElementsNum(); + return RET_ERROR; + } + std::vector output_shape(3); + output_shape[0] = input_shape[0]; + output_shape[1] = input_shape[1]; + output_shape[2] = GetDctCoeffNum(); + outputs_.front()->set_shape(output_shape); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/mfcc.h b/mindspore/lite/src/ops/mfcc.h new file mode 100644 index 0000000000..063a7ddb4b --- /dev/null +++ b/mindspore/lite/src/ops/mfcc.h @@ -0,0 +1,57 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ +#define LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Mfcc : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Mfcc, PrimitiveC); + Mfcc() = default; + explicit Mfcc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + void SetFreqUpperLimit(float freq_upper_limit) { + this->primitive_->value.AsMfcc()->freqUpperLimit = freq_upper_limit; + } + void SetFreqLowerLimit(float freq_lower_limit) { + this->primitive_->value.AsMfcc()->freqLowerLimit = freq_lower_limit; + } + void SetFilterBankChannelNum(int filter_bank_channel_num) { + this->primitive_->value.AsMfcc()->filterBankChannelNum = filter_bank_channel_num; + } + void SetDctCoeffNum(int dct_coeff_num) { this->primitive_->value.AsMfcc()->dctCoeffNum = dct_coeff_num; } +#else + Mfcc() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + float GetFreqUpperLimit() const; + float GetFreqLowerLimit() const; + int GetFilterBankChannelNum() const; + int GetDctCoeffNum() const; + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index e187a41ce2..8f04785bd1 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -137,6 +137,11 @@ #include "src/ops/upsample.h" #include "src/ops/layer_norm.h" #include "src/ops/non_max_suppression.h" +#include "src/ops/rfft.h" +#include "src/ops/fft_real.h" +#include "src/ops/fft_imag.h" +#include "src/ops/audio_spectrogram.h" +#include "src/ops/mfcc.h" #include "src/ops/identity.h" #ifdef SUPPORT_TRAIN @@ -775,6 +780,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new NonMaxSuppression(primitive); case schema::PrimitiveType_Identity: return new Identity(primitive); + case schema::PrimitiveType_Rfft: + return new Rfft(primitive); + case schema::PrimitiveType_FftReal: + return new FftReal(primitive); + case schema::PrimitiveType_FftImag: + return new FftImag(primitive); + case schema::PrimitiveType_AudioSpectrogram: + return new AudioSpectrogram(primitive); + case schema::PrimitiveType_Mfcc: + return new Mfcc(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/src/ops/rfft.cc b/mindspore/lite/src/ops/rfft.cc new file mode 100644 index 0000000000..e83d3210bf --- /dev/null +++ b/mindspore/lite/src/ops/rfft.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/rfft.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int Rfft::GetFftLength() const { return this->primitive_->value.AsRfft()->fftLength; } + +void Rfft::SetFftLength(int fft_length) { this->primitive_->value.AsRfft()->fftLength = fft_length; } + +#else +int Rfft::GetFftLength() const { return this->primitive_->value_as_Rfft()->fftLength(); } +int Rfft::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Rfft(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Add return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateRfft(*fbb, attr->fftLength()); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Rfft, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *RfftCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry RfftRegistry(schema::PrimitiveType_Rfft, RfftCreator); +#endif +int Rfft::InferShape(std::vector inputs_, std::vector outputs_) { + auto input = inputs_.front(); + MS_ASSERT(input != nullptr); + auto output = outputs_.front(); + MS_ASSERT(output != nullptr); + output->set_data_type(TypeId::kNumberTypeComplex64); + output->SetFormat(input->GetFormat()); + if (!GetInferFlag()) { + return RET_OK; + } + auto input_shape = input->shape(); + input_shape[input_shape.size() - 1] = GetFftLength() / 2 + 1; + input_shape.push_back(2); + outputs_.front()->set_shape(input_shape); + return RET_OK; +} +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/rfft.h b/mindspore/lite/src/ops/rfft.h new file mode 100644 index 0000000000..08c01d03ac --- /dev/null +++ b/mindspore/lite/src/ops/rfft.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ +#define LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class Rfft : public PrimitiveC { + public: +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(Rfft, PrimitiveC); + Rfft() = default; + explicit Rfft(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + void SetFftLength(int fft_length); +#else + Rfft() = default; + + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + int GetFftLength() const; + int InferShape(std::vector inputs_, std::vector outputs_) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.cc new file mode 100644 index 0000000000..c1eb314291 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/converter/parser/caffe/caffe_elu_parser.h" +#include + +namespace mindspore { +namespace lite { +STATUS CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, + schema::CNodeT *op, std::vector *weightVec) { + MS_LOG(DEBUG) << "parse CaffeEluParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + if (proto.has_elu_param()) { + const caffe::ELUParameter eluParameter = proto.elu_param(); + if (eluParameter.has_alpha()) { + attr->alpha = eluParameter.alpha(); + } + } + + op->name = proto.name(); + op->primitive->value.type = schema::PrimitiveType_Elu; + op->primitive->value.value = attr.release(); + return RET_OK; +} + +CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.h new file mode 100644 index 0000000000..75b8f19e4b --- /dev/null +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_elu_parser.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ + +#include +#include "tools/converter/parser/caffe/caffe_node_parser.h" +#include "tools/converter/parser/caffe/caffe_node_parser_registry.h" + +namespace mindspore { +namespace lite { +class CaffeEluParser : public CaffeNodeParser { + public: + CaffeEluParser() : CaffeNodeParser("elu") {} + + STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, + std::vector *weightVec) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index b3d8e38243..4c4649ce9d 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -243,7 +243,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); if (status_node != RET_OK) { interrupt = true; - if (status_node == RET_NOT_SUPPORT) { + if (status_node == RET_NOT_FIND_OP) { NoSupportOp::GetInstance()->InsertOp(layer.type()); } else { MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc index 4bb9867ff1..24ad60b9fd 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_deconv_parser.cc @@ -156,8 +156,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N if (attr->group != 1) { if (!ParseGroupDeConvolution(attr, op)) { - MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed"; - return RET_ERROR; + MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support"; + return RET_NOT_SUPPORT; } } else { op->primitive->value.type = schema::PrimitiveType_DeConv2D; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index affba69b3c..101ec38850 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -522,6 +522,7 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr dst_op->primitive->value.value = attr.release(); return RET_OK; } + schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) { TensorCache tensor_cache; // dst_graph->name = onnx_graph.name(); // this is not used @@ -593,6 +594,7 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra SetAllTensors(tensor_cache, dst_graph.get()); return dst_graph.release(); } + schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType) { int status = ValidateFileStr(modelFile, ".onnx"); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc index e05f1cd75a..956b1ee597 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_conv_parser.cc @@ -72,11 +72,12 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu auto data_index = tflite_op->inputs[0]; const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; std::vector params; - if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, - ¶ms) != RET_OK) { + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "get padding params failed"; return RET_ERROR; - } else { + } else if (status == RET_OK) { attr->padUp = params.at(0); attr->padDown = params.at(1); attr->padLeft = params.at(2); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc index 70e5145253..670a8534d6 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_deconv_parser.cc @@ -73,11 +73,12 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni auto data_index = tflite_op->inputs[2]; const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; std::vector params; - if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, - ¶ms) != RET_OK) { + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "get padding params failed"; return RET_ERROR; - } else { + } else if (status == RET_OK) { attr->padUp = params.at(0); attr->padDown = params.at(1); attr->padLeft = params.at(2); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc index 0d3ab72fbe..043f9fff85 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_depthwise_conv_parser.cc @@ -79,11 +79,12 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, // calculate pad params std::vector params; - if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, - ¶ms) != RET_OK) { + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "get padding params failed"; return RET_ERROR; - } else { + } else if (status == RET_OK) { attr->padUp = params.at(0); attr->padDown = params.at(1); attr->padLeft = params.at(2); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 4fecaebc56..4c347e04be 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "tools/common/graph_util.h" #include "tools/common/storage.h" #include "flatbuffers/flatbuffers.h" @@ -102,11 +103,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit for (const auto &tflite_op : tflite_subgraph->operators) { auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; auto op_type = GetMSOpType(tflite_op_type); - if (op_type == "CUSTOM") { - auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code; - MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type; - return RET_ERROR; - } auto op = std::make_unique(); op->name = op_type + "-" + std::to_string(idx++); @@ -122,7 +118,9 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit if (status == RET_OK) { status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); if (status != RET_OK) { - if (status == RET_NOT_SUPPORT) { + if (status == RET_NOT_FIND_OP) { + op_type = + (op_type != "Custom" ? op_type : (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code); NoSupportOp::GetInstance()->InsertOp(op_type); } else { MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; @@ -141,6 +139,16 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr &tflite_subgraph, const std::vector> &tflite_model_buffer, schema::MetaGraphT *sub_graph) { + std::set output_index; + for (const auto &tflite_op : tflite_subgraph->operators) { + for (size_t j = 0; j < tflite_op->outputs.size(); ++j) { + int idx = tflite_op->outputs[j]; + if (idx < 0) { + idx += tflite_subgraph->tensors.size(); + } + output_index.insert(idx); + } + } for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { auto idx = tensorsInfo.tensorsId[i]; if (idx < 0) { @@ -173,11 +181,16 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr return status; } } + // set tensor attr if (isInput || isConst) { tensor->nodeType = schema::NodeType::NodeType_ValueNode; } else { - tensor->nodeType = schema::NodeType_Parameter; + if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) { + tensor->nodeType = schema::NodeType::NodeType_ValueNode; + } else { + tensor->nodeType = schema::NodeType_Parameter; + } } // quant param @@ -246,7 +259,6 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { auto attr = op->primitive->value.AsDepthwiseConv2D(); if (attr->channelMultiplier > 1) { - std::unique_ptr conv_attr = std::make_unique(); // get channel attr if (op->inputIndex.empty()) { MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; @@ -263,7 +275,11 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) return RET_NULL_PTR; } auto data_shape = data_tensor->dims; - + if (data_shape.empty()) { + MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running"; + return RET_NO_CHANGE; + } + std::unique_ptr conv_attr = std::make_unique(); if (data_shape[3] == 1) { conv_attr->channelIn = data_shape[3]; conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; @@ -372,7 +388,7 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, // update for depthwiseConv status = ConvertGroupDepthwiseOp(meta_graph.get()); - if (status != RET_OK) { + if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "convert group depthwise conv failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); return nullptr; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h index 1c8378accc..a426d70a78 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_node_parser.h @@ -86,6 +86,10 @@ class TfliteNodeParser { return RET_NULL_PTR; } auto data_ptr = buf_data->data.data(); + if (data_ptr == nullptr) { + MS_LOG(DEBUG) << "data is not a constant"; + return RET_NO_CHANGE; + } switch (tflite_tensors[tensor_index]->type) { case tflite::TensorType_UINT8: { for (int i = 0; i < count; i++) { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc index 0c7ab33cc1..b8dc009ceb 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pad_parser.cc @@ -71,11 +71,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique break; default: MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; - return RET_INVALID_OP_ATTR; + return RET_NOT_SUPPORT; } } else { MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; - return RET_NOT_SUPPORT; + return RET_NOT_FIND_OP; } op->primitive->value.type = schema::PrimitiveType_Pad; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc index 3d47e9bbed..57832f8ced 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_pooling_parser.cc @@ -71,11 +71,12 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un auto data_index = tflite_op->inputs[0]; const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; std::vector params; - if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, - ¶ms) != RET_OK) { + int status = + getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); + if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "get padding params failed"; return RET_ERROR; - } else { + } else if (status == RET_OK) { attr->padUp = params.at(0); attr->padDown = params.at(1); attr->padLeft = params.at(2); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc index ff3244c260..72ccac5f1e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_range_parser.cc @@ -41,15 +41,31 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq } attr->dType = 0; - // attr->start - // attr->limit - // attr->delta - + std::vector limit; + std::vector delta; + int status = GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, limit); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "range -> limit get failed"; + return RET_ERROR; + } else if (status == RET_OK) { + status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, delta); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "stridedSlice -> end get failed"; + return RET_ERROR; + } + } + if (status == RET_OK) { + attr->limit = limit.front(); + attr->delta = delta.front(); + } op->primitive->value.type = schema::PrimitiveType_Range; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + int input_num = status == RET_OK ? 1 : 3; + for (int i = 0; i < input_num; ++i) { + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + } AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), schema::Format::Format_NHWC); return RET_OK; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc index b764a9d3b9..70a214c5b9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_reduce_parser.cc @@ -69,7 +69,7 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni } else if (std::strcmp(node_name, "ReduceAny") == 0) { // attr->mode; MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; - return RET_NOT_FIND_OP; + return RET_NOT_SUPPORT; } if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc index 66da0beff5..351a2cf75b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_split_parser.cc @@ -67,14 +67,15 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq return RET_ERROR; } attr->splitDim = axis; - if (tensor_shape[axis] % num_splits != 0) { + if (tensor_shape[axis] % num_splits != 0 && tensor_shape[axis] / num_splits != 0) { MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis; return RET_ERROR; } attr->numberSplit = num_splits; - - for (int i = 0; i < num_splits; i++) { - attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); + if (tensor_shape[axis] / num_splits != 0) { + for (int i = 0; i < num_splits; i++) { + attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); + } } op->primitive->value.type = schema::PrimitiveType_Split; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc index a66f8936d8..6242d05e2b 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_strided_slice_parser.cc @@ -52,17 +52,24 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, attr->newAxisMask = tflite_attr->new_axis_mask; attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; - if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { + int status = + GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin); + if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "stridedSlice -> begin get failed"; return RET_ERROR; - } - if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end)) { - MS_LOG(ERROR) << "stridedSlice -> end get failed"; - return RET_ERROR; - } - if (GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride)) { - MS_LOG(ERROR) << "stridedSlice -> stride get failed"; - return RET_ERROR; + } else if (status == RET_OK) { + status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "stridedSlice -> end get failed"; + return RET_ERROR; + } else if (status == RET_OK) { + status = + GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride); + if (status != RET_OK && status != RET_NO_CHANGE) { + MS_LOG(ERROR) << "stridedSlice -> stride get failed"; + return RET_ERROR; + } + } } attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); @@ -70,8 +77,11 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, op->primitive->value.type = schema::PrimitiveType_StridedSlice; op->primitive->value.value = attr.release(); - AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), - schema::Format::Format_NHWC); + int input_num = status == RET_OK ? 1 : 4; + for (int i = 0; i < input_num; ++i) { + AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), + schema::Format::Format_NHWC); + } AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), schema::Format::Format_NHWC); return RET_OK; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index da59b6768d..63347bee39 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -198,7 +198,10 @@ STATUS getPaddingParam(const std::unique_ptr &tensor, schema::P MS_LOG(ERROR) << "the input tensor is null"; return RET_ERROR; } - + if (tensor->shape.empty()) { + MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain nly when running."; + return RET_NO_CHANGE; + } int padUp = 0; int padDown = 0; int padLeft = 0;