Merge pull request !7716 from 徐安越/mastertags/v1.1.0
| @@ -47,6 +47,12 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) { | |||||
| } | } | ||||
| } | } | ||||
// Constructor for a concrete Complex type.
// Only 64-bit complex (Complex64) is supported for now, so any other bit
// width is rejected with an EXCEPTION at construction time.
Complex::Complex(const int nbits) : Number(TypeId::kNumberTypeComplex64, nbits, false) {
  if (nbits != 64) {
    MS_LOG(EXCEPTION) << "Wrong number of bits.";
  }
}
| const TypePtr kBool = std::make_shared<Bool>(); | const TypePtr kBool = std::make_shared<Bool>(); | ||||
| const TypePtr kInt8 = std::make_shared<Int>(8); | const TypePtr kInt8 = std::make_shared<Int>(8); | ||||
| const TypePtr kInt16 = std::make_shared<Int>(16); | const TypePtr kInt16 = std::make_shared<Int>(16); | ||||
| @@ -63,4 +69,5 @@ const TypePtr kInt = std::make_shared<Int>(); | |||||
| const TypePtr kUInt = std::make_shared<UInt>(); | const TypePtr kUInt = std::make_shared<UInt>(); | ||||
| const TypePtr kFloat = std::make_shared<Float>(); | const TypePtr kFloat = std::make_shared<Float>(); | ||||
| const TypePtr kNumber = std::make_shared<Number>(); | const TypePtr kNumber = std::make_shared<Number>(); | ||||
| const TypePtr kComplex64 = std::make_shared<Complex>(64); | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -150,6 +150,28 @@ class Float : public Number { | |||||
| } | } | ||||
| }; | }; | ||||
// Complex
// IR type node for complex numbers. Only Complex64 (nbits == 64) is
// currently supported; a default-constructed instance (nbits == 0) denotes
// the generic "complex" type with no concrete width.
class Complex : public Number {
 public:
  Complex() : Number(kNumberTypeComplex64, 0) {}
  explicit Complex(const int nbits);
  ~Complex() override {}
  MS_DECLARE_PARENT(Complex, Number)
  TypeId generic_type_id() const override { return kNumberTypeComplex64; }
  // Clone, preserving the generic (nbits == 0) vs concrete distinction.
  TypePtr DeepCopy() const override {
    if (nbits() == 0) {
      return std::make_shared<Complex>();
    }
    return std::make_shared<Complex>(nbits());
  }
  std::string ToString() const override { return GetTypeName("Complex64"); }
  std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); }
  // Compact text form: "C<nbits>" for a concrete width, "Complex64" otherwise.
  std::string DumpText() const override {
    return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits());
  }
};
| extern const TypePtr kBool; | extern const TypePtr kBool; | ||||
| extern const TypePtr kInt8; | extern const TypePtr kInt8; | ||||
| extern const TypePtr kInt16; | extern const TypePtr kInt16; | ||||
| @@ -166,6 +188,7 @@ extern const TypePtr kInt; | |||||
| extern const TypePtr kUInt; | extern const TypePtr kUInt; | ||||
| extern const TypePtr kFloat; | extern const TypePtr kFloat; | ||||
| extern const TypePtr kNumber; | extern const TypePtr kNumber; | ||||
| extern const TypePtr kComplex64; | |||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ | #endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ | ||||
| @@ -69,6 +69,8 @@ TypePtr TypeIdToType(TypeId id) { | |||||
| return kFloat32; | return kFloat32; | ||||
| case kNumberTypeFloat64: | case kNumberTypeFloat64: | ||||
| return kFloat64; | return kFloat64; | ||||
| case kNumberTypeComplex64: | |||||
| return kComplex64; | |||||
| case kNumberTypeInt8: | case kNumberTypeInt8: | ||||
| return kInt8; | return kInt8; | ||||
| case kNumberTypeInt16: | case kNumberTypeInt16: | ||||
| @@ -0,0 +1,107 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/audio_spectrogram.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
#ifdef PRIMITIVE_WRITEABLE
// Attribute getters (writeable build): read from the mutable PrimitiveT union.
int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value.AsAudioSpectrogram()->windowSize; }
int AudioSpectrogram::GetStride() const { return this->primitive_->value.AsAudioSpectrogram()->stride; }
bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value.AsAudioSpectrogram()->magSquare; }
#else
// Attribute getters (read-only build): read from the serialized flatbuffer.
int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value_as_AudioSpectrogram()->windowSize(); }
int AudioSpectrogram::GetStride() const { return this->primitive_->value_as_AudioSpectrogram()->stride(); }
bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value_as_AudioSpectrogram()->magSquare(); }
| int AudioSpectrogram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_AudioSpectrogram(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_Add return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateAudioSpectrogram(*fbb, attr->windowSize(), attr->stride(), attr->magSquare()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AudioSpectrogram, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
// Factory for the read-only build; registered below so the op registry can
// construct AudioSpectrogram nodes from serialized primitives.
PrimitiveC *AudioSpectrogramCreator(const schema::Primitive *primitive) {
  return PrimitiveC::NewPrimitiveC<AudioSpectrogram>(primitive);
}
Registry AudioSpectrogramRegistry(schema::PrimitiveType_AudioSpectrogram, AudioSpectrogramCreator);
| #endif | |||||
| int AudioSpectrogram::Log2Ceil(uint32_t length) { | |||||
| if (length == 0) { | |||||
| return -1; | |||||
| } | |||||
| int floor = 0; | |||||
| for (int i = 4; i >= 0; --i) { | |||||
| int shift = (1 << i); | |||||
| uint32_t tmp = length >> shift; | |||||
| if (tmp != 0) { | |||||
| length = tmp; | |||||
| floor += shift; | |||||
| } | |||||
| } | |||||
| return length == (length & ~(length - 1)) ? floor : floor + 1; | |||||
| } | |||||
| uint32_t AudioSpectrogram::GetFftLength(uint32_t length) { | |||||
| int shift = Log2Ceil(length); | |||||
| return 1 << shift; | |||||
| } | |||||
// Infers the output shape of the AudioSpectrogram op from a 2-D
// [samples, channels] input. Output is [channels, frames, fft_bins] where
// frames = 1 + (samples - window) / stride and fft_bins = fft_length / 2 + 1.
int AudioSpectrogram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  // Type and format are propagated even when shape inference is skipped.
  output->set_data_type(input->data_type());
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  if (input_shape.size() != 2) {
    MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions";
    return RET_ERROR;
  }
  // A window needs at least 2 samples and a positive hop to define frames.
  if (GetWindowSize() < 2) {
    MS_LOG(ERROR) << "window size is too short, now is " << GetWindowSize();
    return RET_ERROR;
  }
  if (GetStride() < 1) {
    MS_LOG(ERROR) << "stride must be positive, now is " << GetStride();
    return RET_ERROR;
  }
  std::vector<int> output_shape(3);
  output_shape[0] = input_shape[1];
  // output height: zero frames when the signal is shorter than one window.
  int sample_sub_window = input_shape[0] - GetWindowSize();
  output_shape[1] = sample_sub_window < 0 ? 0 : 1 + sample_sub_window / GetStride();
  // compute fft length
  int fft_length = GetFftLength(GetWindowSize());
  output_shape[2] = fft_length / 2 + 1;
  outputs_.front()->set_shape(output_shape);
  return RET_OK;
}
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
// Lite IR wrapper for the AudioSpectrogram primitive: computes a spectrogram
// (per-frame frequency magnitudes) from a 2-D [samples, channels] signal.
class AudioSpectrogram : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC);
  AudioSpectrogram() = default;
  explicit AudioSpectrogram(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Attribute setters (writeable build only).
  void SetWindowSize(int window_size) { this->primitive_->value.AsAudioSpectrogram()->windowSize = window_size; }
  void SetStride(int stride) { this->primitive_->value.AsAudioSpectrogram()->stride = stride; }
  void SetMagSquare(bool mag_square) { this->primitive_->value.AsAudioSpectrogram()->magSquare = mag_square; }
#else
  AudioSpectrogram() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int GetWindowSize() const;
  int GetStride() const;
  bool GetMagSquare() const;
  // Helpers used by InferShape to derive the FFT size from the window size.
  int Log2Ceil(uint32_t length);
  uint32_t GetFftLength(uint32_t length);
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/fft_imag.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| int FftImag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto val_offset = schema::CreateEqual(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftImag, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
// Factory + registry entry so the read-only build can construct FftImag nodes.
PrimitiveC *FftImagCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftImag>(primitive); }
Registry FftImagRegistry(schema::PrimitiveType_FftImag, FftImagCreator);
| #endif | |||||
// Infers the output shape of FftImag: the input is a complex tensor whose
// innermost dimension (size 2) holds (real, imag) pairs; the output drops
// that dimension and is always float32.
int FftImag::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeFloat32);
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  // Drop the trailing (real, imag) pair dimension.
  input_shape.pop_back();
  outputs_.front()->set_shape(input_shape);
  return RET_OK;
}
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
// Lite IR wrapper for the FftImag primitive: extracts the imaginary part of
// a complex tensor (innermost dimension of 2 holding (real, imag)).
class FftImag : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(FftImag, PrimitiveC);
  FftImag() = default;
  explicit FftImag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
  FftImag() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ | |||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/fft_real.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| int FftReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto val_offset = schema::CreateEqual(*fbb); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftReal, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
// Factory + registry entry so the read-only build can construct FftReal nodes.
PrimitiveC *FftRealCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftReal>(primitive); }
Registry FftRealRegistry(schema::PrimitiveType_FftReal, FftRealCreator);
| #endif | |||||
// Infers the output shape of FftReal: the input is a complex tensor whose
// innermost dimension (size 2) holds (real, imag) pairs; the output drops
// that dimension and is always float32.
int FftReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeFloat32);
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  // Drop the trailing (real, imag) pair dimension.
  input_shape.pop_back();
  outputs_.front()->set_shape(input_shape);
  return RET_OK;
}
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,43 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
// Lite IR wrapper for the FftReal primitive: extracts the real part of a
// complex tensor (innermost dimension of 2 holding (real, imag)).
class FftReal : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(FftReal, PrimitiveC);
  FftReal() = default;
  explicit FftReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
  FftReal() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ | |||||
| @@ -0,0 +1,83 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/mfcc.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
#ifdef PRIMITIVE_WRITEABLE
// Attribute getters (writeable build): read from the mutable PrimitiveT union.
float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value.AsMfcc()->freqUpperLimit; }
float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value.AsMfcc()->freqLowerLimit; }
int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value.AsMfcc()->filterBankChannelNum; }
int Mfcc::GetDctCoeffNum() const { return this->primitive_->value.AsMfcc()->dctCoeffNum; }
#else
// Attribute getters (read-only build): read from the serialized flatbuffer.
float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value_as_Mfcc()->freqUpperLimit(); }
float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value_as_Mfcc()->freqLowerLimit(); }
int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value_as_Mfcc()->filterBankChannelNum(); }
int Mfcc::GetDctCoeffNum() const { return this->primitive_->value_as_Mfcc()->dctCoeffNum(); }
| int Mfcc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_Mfcc(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_Add return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateMfcc(*fbb, attr->freqUpperLimit(), attr->freqLowerLimit(), | |||||
| attr->filterBankChannelNum(), attr->dctCoeffNum()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mfcc, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
// Factory + registry entry so the read-only build can construct Mfcc nodes.
PrimitiveC *MfccCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mfcc>(primitive); }
Registry MfccRegistry(schema::PrimitiveType_Mfcc, MfccCreator);
| #endif | |||||
| int Mfcc::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||||
| auto input = inputs_.front(); | |||||
| MS_ASSERT(input != nullptr); | |||||
| auto output = outputs_.front(); | |||||
| MS_ASSERT(output != nullptr); | |||||
| output->set_data_type(input->data_type()); | |||||
| output->SetFormat(input->GetFormat()); | |||||
| if (!GetInferFlag()) { | |||||
| return RET_OK; | |||||
| } | |||||
| auto input_shape = input->shape(); | |||||
| if (input_shape.size() != 3) { | |||||
| MS_LOG(ERROR) << "first input shape is error, which need to be 3 dimensions, but the dimension is " | |||||
| << input_shape.size(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (inputs_[1]->ElementsNum() != 1) { | |||||
| MS_LOG(ERROR) << "second input element num is error, which need only a value, but the number is " | |||||
| << inputs_[1]->ElementsNum(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::vector<int> output_shape(3); | |||||
| output_shape[0] = input_shape[0]; | |||||
| output_shape[1] = input_shape[1]; | |||||
| output_shape[2] = GetDctCoeffNum(); | |||||
| outputs_.front()->set_shape(output_shape); | |||||
| return RET_OK; | |||||
| } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,57 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
// Lite IR wrapper for the Mfcc primitive: converts a spectrogram into
// mel-frequency cepstral coefficients.
class Mfcc : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Mfcc, PrimitiveC);
  Mfcc() = default;
  explicit Mfcc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Attribute setters (writeable build only).
  void SetFreqUpperLimit(float freq_upper_limit) {
    this->primitive_->value.AsMfcc()->freqUpperLimit = freq_upper_limit;
  }
  void SetFreqLowerLimit(float freq_lower_limit) {
    this->primitive_->value.AsMfcc()->freqLowerLimit = freq_lower_limit;
  }
  void SetFilterBankChannelNum(int filter_bank_channel_num) {
    this->primitive_->value.AsMfcc()->filterBankChannelNum = filter_bank_channel_num;
  }
  void SetDctCoeffNum(int dct_coeff_num) { this->primitive_->value.AsMfcc()->dctCoeffNum = dct_coeff_num; }
#else
  Mfcc() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  float GetFreqUpperLimit() const;
  float GetFreqLowerLimit() const;
  int GetFilterBankChannelNum() const;
  int GetDctCoeffNum() const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ | |||||
| @@ -137,6 +137,11 @@ | |||||
| #include "src/ops/upsample.h" | #include "src/ops/upsample.h" | ||||
| #include "src/ops/layer_norm.h" | #include "src/ops/layer_norm.h" | ||||
| #include "src/ops/non_max_suppression.h" | #include "src/ops/non_max_suppression.h" | ||||
| #include "src/ops/rfft.h" | |||||
| #include "src/ops/fft_real.h" | |||||
| #include "src/ops/fft_imag.h" | |||||
| #include "src/ops/audio_spectrogram.h" | |||||
| #include "src/ops/mfcc.h" | |||||
| #include "src/ops/identity.h" | #include "src/ops/identity.h" | ||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| @@ -775,6 +780,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||||
| return new NonMaxSuppression(primitive); | return new NonMaxSuppression(primitive); | ||||
| case schema::PrimitiveType_Identity: | case schema::PrimitiveType_Identity: | ||||
| return new Identity(primitive); | return new Identity(primitive); | ||||
| case schema::PrimitiveType_Rfft: | |||||
| return new Rfft(primitive); | |||||
| case schema::PrimitiveType_FftReal: | |||||
| return new FftReal(primitive); | |||||
| case schema::PrimitiveType_FftImag: | |||||
| return new FftImag(primitive); | |||||
| case schema::PrimitiveType_AudioSpectrogram: | |||||
| return new AudioSpectrogram(primitive); | |||||
| case schema::PrimitiveType_Mfcc: | |||||
| return new Mfcc(primitive); | |||||
| #ifdef SUPPORT_TRAIN | #ifdef SUPPORT_TRAIN | ||||
| case schema::PrimitiveType_ActivationGrad: | case schema::PrimitiveType_ActivationGrad: | ||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/rfft.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
#ifdef PRIMITIVE_WRITEABLE
// Attribute accessors (writeable build): read/write the mutable PrimitiveT union.
int Rfft::GetFftLength() const { return this->primitive_->value.AsRfft()->fftLength; }
void Rfft::SetFftLength(int fft_length) { this->primitive_->value.AsRfft()->fftLength = fft_length; }
#else
// Attribute getter (read-only build): read from the serialized flatbuffer.
int Rfft::GetFftLength() const { return this->primitive_->value_as_Rfft()->fftLength(); }
| int Rfft::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_Rfft(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_Add return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateRfft(*fbb, attr->fftLength()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Rfft, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
// Factory + registry entry so the read-only build can construct Rfft nodes.
PrimitiveC *RfftCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Rfft>(primitive); }
Registry RfftRegistry(schema::PrimitiveType_Rfft, RfftCreator);
| #endif | |||||
// Infers the output shape of Rfft: the last dimension becomes
// fftLength / 2 + 1 frequency bins, and a trailing dimension of 2 is
// appended for the (real, imag) components of each complex value.
int Rfft::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto input = inputs_.front();
  MS_ASSERT(input != nullptr);
  auto output = outputs_.front();
  MS_ASSERT(output != nullptr);
  output->set_data_type(TypeId::kNumberTypeComplex64);
  output->SetFormat(input->GetFormat());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  auto input_shape = input->shape();
  // NOTE(review): assumes input_shape is non-empty — confirm callers
  // guarantee at least a 1-D input, otherwise this indexes out of range.
  input_shape[input_shape.size() - 1] = GetFftLength() / 2 + 1;
  input_shape.push_back(2);
  outputs_.front()->set_shape(input_shape);
  return RET_OK;
}
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
// Lite IR wrapper for the Rfft primitive: real-input FFT producing
// fftLength / 2 + 1 complex bins per transform.
class Rfft : public PrimitiveC {
 public:
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(Rfft, PrimitiveC);
  Rfft() = default;
  explicit Rfft(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  void SetFftLength(int fft_length);
#else
  Rfft() = default;
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  int GetFftLength() const;
  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/converter/parser/caffe/caffe_elu_parser.h" | |||||
| #include <memory> | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| STATUS CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||||
| MS_LOG(DEBUG) << "parse CaffeEluParser"; | |||||
| if (op == nullptr) { | |||||
| MS_LOG(ERROR) << "op is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (op->primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "op->primitive is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (proto.has_elu_param()) { | |||||
| const caffe::ELUParameter eluParameter = proto.elu_param(); | |||||
| if (eluParameter.has_alpha()) { | |||||
| attr->alpha = eluParameter.alpha(); | |||||
| } | |||||
| } | |||||
| op->name = proto.name(); | |||||
| op->primitive->value.type = schema::PrimitiveType_Elu; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser()); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ | |||||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ | |||||
| #include <vector> | |||||
| #include "tools/converter/parser/caffe/caffe_node_parser.h" | |||||
| #include "tools/converter/parser/caffe/caffe_node_parser_registry.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class CaffeEluParser : public CaffeNodeParser { | |||||
| public: | |||||
| CaffeEluParser() : CaffeNodeParser("elu") {} | |||||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||||
| std::vector<schema::TensorT *> *weightVec) override; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ | |||||
| @@ -243,7 +243,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff | |||||
| auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); | auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); | ||||
| if (status_node != RET_OK) { | if (status_node != RET_OK) { | ||||
| interrupt = true; | interrupt = true; | ||||
| if (status_node == RET_NOT_SUPPORT) { | |||||
| if (status_node == RET_NOT_FIND_OP) { | |||||
| NoSupportOp::GetInstance()->InsertOp(layer.type()); | NoSupportOp::GetInstance()->InsertOp(layer.type()); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; | MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; | ||||
| @@ -156,8 +156,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||||
| if (attr->group != 1) { | if (attr->group != 1) { | ||||
| if (!ParseGroupDeConvolution(attr, op)) { | if (!ParseGroupDeConvolution(attr, op)) { | ||||
| MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed"; | |||||
| return RET_ERROR; | |||||
| MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support"; | |||||
| return RET_NOT_SUPPORT; | |||||
| } | } | ||||
| } else { | } else { | ||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | op->primitive->value.type = schema::PrimitiveType_DeConv2D; | ||||
| @@ -522,6 +522,7 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr | |||||
| dst_op->primitive->value.value = attr.release(); | dst_op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) { | schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) { | ||||
| TensorCache tensor_cache; | TensorCache tensor_cache; | ||||
| // dst_graph->name = onnx_graph.name(); // this is not used | // dst_graph->name = onnx_graph.name(); // this is not used | ||||
| @@ -593,6 +594,7 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra | |||||
| SetAllTensors(tensor_cache, dst_graph.get()); | SetAllTensors(tensor_cache, dst_graph.get()); | ||||
| return dst_graph.release(); | return dst_graph.release(); | ||||
| } | } | ||||
| schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | ||||
| const QuantType &quantType) { | const QuantType &quantType) { | ||||
| int status = ValidateFileStr(modelFile, ".onnx"); | int status = ValidateFileStr(modelFile, ".onnx"); | ||||
| @@ -72,11 +72,12 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | ||||
| std::vector<int> params; | std::vector<int> params; | ||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | |||||
| ¶ms) != RET_OK) { | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | MS_LOG(ERROR) << "get padding params failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else { | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | attr->padUp = params.at(0); | ||||
| attr->padDown = params.at(1); | attr->padDown = params.at(1); | ||||
| attr->padLeft = params.at(2); | attr->padLeft = params.at(2); | ||||
| @@ -73,11 +73,12 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| auto data_index = tflite_op->inputs[2]; | auto data_index = tflite_op->inputs[2]; | ||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | ||||
| std::vector<int> params; | std::vector<int> params; | ||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | |||||
| ¶ms) != RET_OK) { | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | MS_LOG(ERROR) << "get padding params failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else { | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | attr->padUp = params.at(0); | ||||
| attr->padDown = params.at(1); | attr->padDown = params.at(1); | ||||
| attr->padLeft = params.at(2); | attr->padLeft = params.at(2); | ||||
| @@ -79,11 +79,12 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| // calculate pad params | // calculate pad params | ||||
| std::vector<int> params; | std::vector<int> params; | ||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | |||||
| ¶ms) != RET_OK) { | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | MS_LOG(ERROR) << "get padding params failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else { | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | attr->padUp = params.at(0); | ||||
| attr->padDown = params.at(1); | attr->padDown = params.at(1); | ||||
| attr->padLeft = params.at(2); | attr->padLeft = params.at(2); | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | |||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| #include "tools/common/storage.h" | #include "tools/common/storage.h" | ||||
| #include "flatbuffers/flatbuffers.h" | #include "flatbuffers/flatbuffers.h" | ||||
| @@ -102,11 +103,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | for (const auto &tflite_op : tflite_subgraph->operators) { | ||||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | ||||
| auto op_type = GetMSOpType(tflite_op_type); | auto op_type = GetMSOpType(tflite_op_type); | ||||
| if (op_type == "CUSTOM") { | |||||
| auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code; | |||||
| MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto op = std::make_unique<schema::CNodeT>(); | auto op = std::make_unique<schema::CNodeT>(); | ||||
| op->name = op_type + "-" + std::to_string(idx++); | op->name = op_type + "-" + std::to_string(idx++); | ||||
| @@ -122,7 +118,9 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| if (status == RET_OK) { | if (status == RET_OK) { | ||||
| status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); | status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); | ||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| if (status == RET_NOT_SUPPORT) { | |||||
| if (status == RET_NOT_FIND_OP) { | |||||
| op_type = | |||||
| (op_type != "Custom" ? op_type : (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code); | |||||
| NoSupportOp::GetInstance()->InsertOp(op_type); | NoSupportOp::GetInstance()->InsertOp(op_type); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | ||||
| @@ -141,6 +139,16 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| schema::MetaGraphT *sub_graph) { | schema::MetaGraphT *sub_graph) { | ||||
| std::set<int> output_index; | |||||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||||
| for (size_t j = 0; j < tflite_op->outputs.size(); ++j) { | |||||
| int idx = tflite_op->outputs[j]; | |||||
| if (idx < 0) { | |||||
| idx += tflite_subgraph->tensors.size(); | |||||
| } | |||||
| output_index.insert(idx); | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { | for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { | ||||
| auto idx = tensorsInfo.tensorsId[i]; | auto idx = tensorsInfo.tensorsId[i]; | ||||
| if (idx < 0) { | if (idx < 0) { | ||||
| @@ -173,11 +181,16 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> | |||||
| return status; | return status; | ||||
| } | } | ||||
| } | } | ||||
| // set tensor attr | // set tensor attr | ||||
| if (isInput || isConst) { | if (isInput || isConst) { | ||||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | tensor->nodeType = schema::NodeType::NodeType_ValueNode; | ||||
| } else { | } else { | ||||
| tensor->nodeType = schema::NodeType_Parameter; | |||||
| if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) { | |||||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||||
| } else { | |||||
| tensor->nodeType = schema::NodeType_Parameter; | |||||
| } | |||||
| } | } | ||||
| // quant param | // quant param | ||||
| @@ -246,7 +259,6 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) | |||||
| if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| auto attr = op->primitive->value.AsDepthwiseConv2D(); | auto attr = op->primitive->value.AsDepthwiseConv2D(); | ||||
| if (attr->channelMultiplier > 1) { | if (attr->channelMultiplier > 1) { | ||||
| std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>(); | |||||
| // get channel attr | // get channel attr | ||||
| if (op->inputIndex.empty()) { | if (op->inputIndex.empty()) { | ||||
| MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; | MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; | ||||
| @@ -263,7 +275,11 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto data_shape = data_tensor->dims; | auto data_shape = data_tensor->dims; | ||||
| if (data_shape.empty()) { | |||||
| MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running"; | |||||
| return RET_NO_CHANGE; | |||||
| } | |||||
| std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>(); | |||||
| if (data_shape[3] == 1) { | if (data_shape[3] == 1) { | ||||
| conv_attr->channelIn = data_shape[3]; | conv_attr->channelIn = data_shape[3]; | ||||
| conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | ||||
| @@ -372,7 +388,7 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, | |||||
| // update for depthwiseConv | // update for depthwiseConv | ||||
| status = ConvertGroupDepthwiseOp(meta_graph.get()); | status = ConvertGroupDepthwiseOp(meta_graph.get()); | ||||
| if (status != RET_OK) { | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "convert group depthwise conv failed"; | MS_LOG(ERROR) << "convert group depthwise conv failed"; | ||||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -86,6 +86,10 @@ class TfliteNodeParser { | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto data_ptr = buf_data->data.data(); | auto data_ptr = buf_data->data.data(); | ||||
| if (data_ptr == nullptr) { | |||||
| MS_LOG(DEBUG) << "data is not a constant"; | |||||
| return RET_NO_CHANGE; | |||||
| } | |||||
| switch (tflite_tensors[tensor_index]->type) { | switch (tflite_tensors[tensor_index]->type) { | ||||
| case tflite::TensorType_UINT8: { | case tflite::TensorType_UINT8: { | ||||
| for (int i = 0; i < count; i++) { | for (int i = 0; i < count; i++) { | ||||
| @@ -71,11 +71,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique | |||||
| break; | break; | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | ||||
| return RET_INVALID_OP_ATTR; | |||||
| return RET_NOT_SUPPORT; | |||||
| } | } | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; | MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; | ||||
| return RET_NOT_SUPPORT; | |||||
| return RET_NOT_FIND_OP; | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Pad; | op->primitive->value.type = schema::PrimitiveType_Pad; | ||||
| @@ -71,11 +71,12 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | ||||
| std::vector<int> params; | std::vector<int> params; | ||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, | |||||
| ¶ms) != RET_OK) { | |||||
| int status = | |||||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "get padding params failed"; | MS_LOG(ERROR) << "get padding params failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } else { | |||||
| } else if (status == RET_OK) { | |||||
| attr->padUp = params.at(0); | attr->padUp = params.at(0); | ||||
| attr->padDown = params.at(1); | attr->padDown = params.at(1); | ||||
| attr->padLeft = params.at(2); | attr->padLeft = params.at(2); | ||||
| @@ -41,15 +41,31 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| } | } | ||||
| attr->dType = 0; | attr->dType = 0; | ||||
| // attr->start | |||||
| // attr->limit | |||||
| // attr->delta | |||||
| std::vector<int> limit; | |||||
| std::vector<int> delta; | |||||
| int status = GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, limit); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "range -> limit get failed"; | |||||
| return RET_ERROR; | |||||
| } else if (status == RET_OK) { | |||||
| status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, delta); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| if (status == RET_OK) { | |||||
| attr->limit = limit.front(); | |||||
| attr->delta = delta.front(); | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Range; | op->primitive->value.type = schema::PrimitiveType_Range; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| int input_num = status == RET_OK ? 1 : 3; | |||||
| for (int i = 0; i < input_num; ++i) { | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | ||||
| schema::Format::Format_NHWC); | schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -69,7 +69,7 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| } else if (std::strcmp(node_name, "ReduceAny") == 0) { | } else if (std::strcmp(node_name, "ReduceAny") == 0) { | ||||
| // attr->mode; | // attr->mode; | ||||
| MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; | MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; | ||||
| return RET_NOT_FIND_OP; | |||||
| return RET_NOT_SUPPORT; | |||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { | if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { | ||||
| @@ -67,14 +67,15 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->splitDim = axis; | attr->splitDim = axis; | ||||
| if (tensor_shape[axis] % num_splits != 0) { | |||||
| if (tensor_shape[axis] % num_splits != 0 && tensor_shape[axis] / num_splits != 0) { | |||||
| MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis; | MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->numberSplit = num_splits; | attr->numberSplit = num_splits; | ||||
| for (int i = 0; i < num_splits; i++) { | |||||
| attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); | |||||
| if (tensor_shape[axis] / num_splits != 0) { | |||||
| for (int i = 0; i < num_splits; i++) { | |||||
| attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); | |||||
| } | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Split; | op->primitive->value.type = schema::PrimitiveType_Split; | ||||
| @@ -52,17 +52,24 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| attr->newAxisMask = tflite_attr->new_axis_mask; | attr->newAxisMask = tflite_attr->new_axis_mask; | ||||
| attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; | attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { | |||||
| int status = | |||||
| GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "stridedSlice -> begin get failed"; | MS_LOG(ERROR) << "stridedSlice -> begin get failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end)) { | |||||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride)) { | |||||
| MS_LOG(ERROR) << "stridedSlice -> stride get failed"; | |||||
| return RET_ERROR; | |||||
| } else if (status == RET_OK) { | |||||
| status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||||
| return RET_ERROR; | |||||
| } else if (status == RET_OK) { | |||||
| status = | |||||
| GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride); | |||||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||||
| MS_LOG(ERROR) << "stridedSlice -> stride get failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), | attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), | ||||
| tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); | tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); | ||||
| @@ -70,8 +77,11 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| op->primitive->value.type = schema::PrimitiveType_StridedSlice; | op->primitive->value.type = schema::PrimitiveType_StridedSlice; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| int input_num = status == RET_OK ? 1 : 4; | |||||
| for (int i = 0; i < input_num; ++i) { | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | ||||
| schema::Format::Format_NHWC); | schema::Format::Format_NHWC); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -198,7 +198,10 @@ STATUS getPaddingParam(const std::unique_ptr<tflite::TensorT> &tensor, schema::P | |||||
| MS_LOG(ERROR) << "the input tensor is null"; | MS_LOG(ERROR) << "the input tensor is null"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (tensor->shape.empty()) { | |||||
| MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain nly when running."; | |||||
| return RET_NO_CHANGE; | |||||
| } | |||||
| int padUp = 0; | int padUp = 0; | ||||
| int padDown = 0; | int padDown = 0; | ||||
| int padLeft = 0; | int padLeft = 0; | ||||