Merge pull request !7716 from 徐安越/mastertags/v1.1.0
| @@ -47,6 +47,12 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) { | |||
| } | |||
| } | |||
| Complex::Complex(const int nbits) : Number(TypeId::kNumberTypeComplex64, nbits, false) { | |||
| if (nbits != 64) { | |||
| MS_LOG(EXCEPTION) << "Wrong number of bits."; | |||
| } | |||
| } | |||
| const TypePtr kBool = std::make_shared<Bool>(); | |||
| const TypePtr kInt8 = std::make_shared<Int>(8); | |||
| const TypePtr kInt16 = std::make_shared<Int>(16); | |||
| @@ -63,4 +69,5 @@ const TypePtr kInt = std::make_shared<Int>(); | |||
| const TypePtr kUInt = std::make_shared<UInt>(); | |||
| const TypePtr kFloat = std::make_shared<Float>(); | |||
| const TypePtr kNumber = std::make_shared<Number>(); | |||
| const TypePtr kComplex64 = std::make_shared<Complex>(64); | |||
| } // namespace mindspore | |||
| @@ -150,6 +150,28 @@ class Float : public Number { | |||
| } | |||
| }; | |||
| // Complex | |||
| class Complex : public Number { | |||
| public: | |||
| Complex() : Number(kNumberTypeComplex64, 0) {} | |||
| explicit Complex(const int nbits); | |||
| ~Complex() override {} | |||
| MS_DECLARE_PARENT(Complex, Number) | |||
| TypeId generic_type_id() const override { return kNumberTypeComplex64; } | |||
| TypePtr DeepCopy() const override { | |||
| if (nbits() == 0) { | |||
| return std::make_shared<Complex>(); | |||
| } | |||
| return std::make_shared<Complex>(nbits()); | |||
| } | |||
| std::string ToString() const override { return GetTypeName("Complex64"); } | |||
| std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); } | |||
| std::string DumpText() const override { | |||
| return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits()); | |||
| } | |||
| }; | |||
| extern const TypePtr kBool; | |||
| extern const TypePtr kInt8; | |||
| extern const TypePtr kInt16; | |||
| @@ -166,6 +188,7 @@ extern const TypePtr kInt; | |||
| extern const TypePtr kUInt; | |||
| extern const TypePtr kFloat; | |||
| extern const TypePtr kNumber; | |||
| extern const TypePtr kComplex64; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ | |||
| @@ -69,6 +69,8 @@ TypePtr TypeIdToType(TypeId id) { | |||
| return kFloat32; | |||
| case kNumberTypeFloat64: | |||
| return kFloat64; | |||
| case kNumberTypeComplex64: | |||
| return kComplex64; | |||
| case kNumberTypeInt8: | |||
| return kInt8; | |||
| case kNumberTypeInt16: | |||
| @@ -0,0 +1,107 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/audio_spectrogram.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value.AsAudioSpectrogram()->windowSize; } | |||
| int AudioSpectrogram::GetStride() const { return this->primitive_->value.AsAudioSpectrogram()->stride; } | |||
| bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value.AsAudioSpectrogram()->magSquare; } | |||
| #else | |||
| int AudioSpectrogram::GetWindowSize() const { return this->primitive_->value_as_AudioSpectrogram()->windowSize(); } | |||
| int AudioSpectrogram::GetStride() const { return this->primitive_->value_as_AudioSpectrogram()->stride(); } | |||
| bool AudioSpectrogram::GetMagSquare() const { return this->primitive_->value_as_AudioSpectrogram()->magSquare(); } | |||
| int AudioSpectrogram::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_AudioSpectrogram(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Add return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateAudioSpectrogram(*fbb, attr->windowSize(), attr->stride(), attr->magSquare()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_AudioSpectrogram, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *AudioSpectrogramCreator(const schema::Primitive *primitive) { | |||
| return PrimitiveC::NewPrimitiveC<AudioSpectrogram>(primitive); | |||
| } | |||
| Registry AudioSpectrogramRegistry(schema::PrimitiveType_AudioSpectrogram, AudioSpectrogramCreator); | |||
| #endif | |||
| int AudioSpectrogram::Log2Ceil(uint32_t length) { | |||
| if (length == 0) { | |||
| return -1; | |||
| } | |||
| int floor = 0; | |||
| for (int i = 4; i >= 0; --i) { | |||
| int shift = (1 << i); | |||
| uint32_t tmp = length >> shift; | |||
| if (tmp != 0) { | |||
| length = tmp; | |||
| floor += shift; | |||
| } | |||
| } | |||
| return length == (length & ~(length - 1)) ? floor : floor + 1; | |||
| } | |||
| uint32_t AudioSpectrogram::GetFftLength(uint32_t length) { | |||
| int shift = Log2Ceil(length); | |||
| return 1 << shift; | |||
| } | |||
| int AudioSpectrogram::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| if (input_shape.size() != 2) { | |||
| MS_LOG(ERROR) << "input shape is error, which need to be 2 dimensions"; | |||
| return RET_ERROR; | |||
| } | |||
| if (GetWindowSize() < 2) { | |||
| MS_LOG(ERROR) << "window size is too short, now is " << GetWindowSize(); | |||
| return RET_ERROR; | |||
| } | |||
| if (GetStride() < 1) { | |||
| MS_LOG(ERROR) << "stride must be positive, now is " << GetStride(); | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> output_shape(3); | |||
| output_shape[0] = input_shape[1]; | |||
| // output height | |||
| int sample_sub_window = input_shape[0] - GetWindowSize(); | |||
| output_shape[1] = sample_sub_window < 0 ? 0 : 1 + sample_sub_window / GetStride(); | |||
| // compute fft length | |||
| int fft_length = GetFftLength(GetWindowSize()); | |||
| output_shape[2] = fft_length / 2 + 1; | |||
| outputs_.front()->set_shape(output_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class AudioSpectrogram : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(AudioSpectrogram, PrimitiveC); | |||
| AudioSpectrogram() = default; | |||
| explicit AudioSpectrogram(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetWindowSize(int window_size) { this->primitive_->value.AsAudioSpectrogram()->windowSize = window_size; } | |||
| void SetStride(int stride) { this->primitive_->value.AsAudioSpectrogram()->stride = stride; } | |||
| void SetMagSquare(bool mag_square) { this->primitive_->value.AsAudioSpectrogram()->magSquare = mag_square; } | |||
| #else | |||
| AudioSpectrogram() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int GetWindowSize() const; | |||
| int GetStride() const; | |||
| bool GetMagSquare() const; | |||
| int Log2Ceil(uint32_t length); | |||
| uint32_t GetFftLength(uint32_t length); | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_AUDIO_SPECTROGRAM_H_ | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/fft_imag.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| int FftImag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateEqual(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftImag, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *FftImagCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftImag>(primitive); } | |||
| Registry FftImagRegistry(schema::PrimitiveType_FftImag, FftImagCreator); | |||
| #endif | |||
| int FftImag::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(TypeId::kNumberTypeFloat32); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| input_shape.pop_back(); | |||
| outputs_.front()->set_shape(input_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class FftImag : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(FftImag, PrimitiveC); | |||
| FftImag() = default; | |||
| explicit FftImag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| #else | |||
| FftImag() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FFT_IMAG_H_ | |||
| @@ -0,0 +1,54 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/fft_real.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| int FftReal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto val_offset = schema::CreateEqual(*fbb); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_FftReal, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *FftRealCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<FftReal>(primitive); } | |||
| Registry FftRealRegistry(schema::PrimitiveType_FftReal, FftRealCreator); | |||
| #endif | |||
| int FftReal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(TypeId::kNumberTypeFloat32); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| input_shape.pop_back(); | |||
| outputs_.front()->set_shape(input_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class FftReal : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(FftReal, PrimitiveC); | |||
| FftReal() = default; | |||
| explicit FftReal(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| #else | |||
| FftReal() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_FFT_REAL_H_ | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/mfcc.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value.AsMfcc()->freqUpperLimit; } | |||
| float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value.AsMfcc()->freqLowerLimit; } | |||
| int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value.AsMfcc()->filterBankChannelNum; } | |||
| int Mfcc::GetDctCoeffNum() const { return this->primitive_->value.AsMfcc()->dctCoeffNum; } | |||
| #else | |||
| float Mfcc::GetFreqUpperLimit() const { return this->primitive_->value_as_Mfcc()->freqUpperLimit(); } | |||
| float Mfcc::GetFreqLowerLimit() const { return this->primitive_->value_as_Mfcc()->freqLowerLimit(); } | |||
| int Mfcc::GetFilterBankChannelNum() const { return this->primitive_->value_as_Mfcc()->filterBankChannelNum(); } | |||
| int Mfcc::GetDctCoeffNum() const { return this->primitive_->value_as_Mfcc()->dctCoeffNum(); } | |||
| int Mfcc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Mfcc(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Add return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateMfcc(*fbb, attr->freqUpperLimit(), attr->freqLowerLimit(), | |||
| attr->filterBankChannelNum(), attr->dctCoeffNum()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Mfcc, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *MfccCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Mfcc>(primitive); } | |||
| Registry MfccRegistry(schema::PrimitiveType_Mfcc, MfccCreator); | |||
| #endif | |||
| int Mfcc::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(input->data_type()); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| if (input_shape.size() != 3) { | |||
| MS_LOG(ERROR) << "first input shape is error, which need to be 3 dimensions, but the dimension is " | |||
| << input_shape.size(); | |||
| return RET_ERROR; | |||
| } | |||
| if (inputs_[1]->ElementsNum() != 1) { | |||
| MS_LOG(ERROR) << "second input element num is error, which need only a value, but the number is " | |||
| << inputs_[1]->ElementsNum(); | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> output_shape(3); | |||
| output_shape[0] = input_shape[0]; | |||
| output_shape[1] = input_shape[1]; | |||
| output_shape[2] = GetDctCoeffNum(); | |||
| outputs_.front()->set_shape(output_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,57 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Mfcc : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Mfcc, PrimitiveC); | |||
| Mfcc() = default; | |||
| explicit Mfcc(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetFreqUpperLimit(float freq_upper_limit) { | |||
| this->primitive_->value.AsMfcc()->freqUpperLimit = freq_upper_limit; | |||
| } | |||
| void SetFreqLowerLimit(float freq_lower_limit) { | |||
| this->primitive_->value.AsMfcc()->freqLowerLimit = freq_lower_limit; | |||
| } | |||
| void SetFilterBankChannelNum(int filter_bank_channel_num) { | |||
| this->primitive_->value.AsMfcc()->filterBankChannelNum = filter_bank_channel_num; | |||
| } | |||
| void SetDctCoeffNum(int dct_coeff_num) { this->primitive_->value.AsMfcc()->dctCoeffNum = dct_coeff_num; } | |||
| #else | |||
| Mfcc() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| float GetFreqUpperLimit() const; | |||
| float GetFreqLowerLimit() const; | |||
| int GetFilterBankChannelNum() const; | |||
| int GetDctCoeffNum() const; | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_MFCC_H_ | |||
| @@ -137,6 +137,11 @@ | |||
| #include "src/ops/upsample.h" | |||
| #include "src/ops/layer_norm.h" | |||
| #include "src/ops/non_max_suppression.h" | |||
| #include "src/ops/rfft.h" | |||
| #include "src/ops/fft_real.h" | |||
| #include "src/ops/fft_imag.h" | |||
| #include "src/ops/audio_spectrogram.h" | |||
| #include "src/ops/mfcc.h" | |||
| #include "src/ops/identity.h" | |||
| #ifdef SUPPORT_TRAIN | |||
| @@ -775,6 +780,16 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { | |||
| return new NonMaxSuppression(primitive); | |||
| case schema::PrimitiveType_Identity: | |||
| return new Identity(primitive); | |||
| case schema::PrimitiveType_Rfft: | |||
| return new Rfft(primitive); | |||
| case schema::PrimitiveType_FftReal: | |||
| return new FftReal(primitive); | |||
| case schema::PrimitiveType_FftImag: | |||
| return new FftImag(primitive); | |||
| case schema::PrimitiveType_AudioSpectrogram: | |||
| return new AudioSpectrogram(primitive); | |||
| case schema::PrimitiveType_Mfcc: | |||
| return new Mfcc(primitive); | |||
| #ifdef SUPPORT_TRAIN | |||
| case schema::PrimitiveType_ActivationGrad: | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/rfft.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| int Rfft::GetFftLength() const { return this->primitive_->value.AsRfft()->fftLength; } | |||
| void Rfft::SetFftLength(int fft_length) { this->primitive_->value.AsRfft()->fftLength = fft_length; } | |||
| #else | |||
| int Rfft::GetFftLength() const { return this->primitive_->value_as_Rfft()->fftLength(); } | |||
| int Rfft::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Rfft(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Add return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| auto val_offset = schema::CreateRfft(*fbb, attr->fftLength()); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Rfft, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *RfftCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Rfft>(primitive); } | |||
| Registry RfftRegistry(schema::PrimitiveType_Rfft, RfftCreator); | |||
| #endif | |||
| int Rfft::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_data_type(TypeId::kNumberTypeComplex64); | |||
| output->SetFormat(input->GetFormat()); | |||
| if (!GetInferFlag()) { | |||
| return RET_OK; | |||
| } | |||
| auto input_shape = input->shape(); | |||
| input_shape[input_shape.size() - 1] = GetFftLength() / 2 + 1; | |||
| input_shape.push_back(2); | |||
| outputs_.front()->set_shape(input_shape); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Rfft : public PrimitiveC { | |||
| public: | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Rfft, PrimitiveC); | |||
| Rfft() = default; | |||
| explicit Rfft(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| void SetFftLength(int fft_length); | |||
| #else | |||
| Rfft() = default; | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int GetFftLength() const; | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_RFFT_H_ | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/parser/caffe/caffe_elu_parser.h" | |||
| #include <memory> | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS CaffeEluParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||
| MS_LOG(DEBUG) << "parse CaffeEluParser"; | |||
| if (op == nullptr) { | |||
| MS_LOG(ERROR) << "op is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| op->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (op->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "op->primitive is null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| std::unique_ptr<schema::EluT> attr = std::make_unique<schema::EluT>(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "new op failed"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| if (proto.has_elu_param()) { | |||
| const caffe::ELUParameter eluParameter = proto.elu_param(); | |||
| if (eluParameter.has_alpha()) { | |||
| attr->alpha = eluParameter.alpha(); | |||
| } | |||
| } | |||
| op->name = proto.name(); | |||
| op->primitive->value.type = schema::PrimitiveType_Elu; | |||
| op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| CaffeNodeRegistrar g_caffeEluParser("ELU", new CaffeEluParser()); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ | |||
| #include <vector> | |||
| #include "tools/converter/parser/caffe/caffe_node_parser.h" | |||
| #include "tools/converter/parser/caffe/caffe_node_parser_registry.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class CaffeEluParser : public CaffeNodeParser { | |||
| public: | |||
| CaffeEluParser() : CaffeNodeParser("elu") {} | |||
| STATUS Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, schema::CNodeT *op, | |||
| std::vector<schema::TensorT *> *weightVec) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_ELU_PARSER_H_ | |||
| @@ -243,7 +243,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff | |||
| auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); | |||
| if (status_node != RET_OK) { | |||
| interrupt = true; | |||
| if (status_node == RET_NOT_SUPPORT) { | |||
| if (status_node == RET_NOT_FIND_OP) { | |||
| NoSupportOp::GetInstance()->InsertOp(layer.type()); | |||
| } else { | |||
| MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; | |||
| @@ -156,8 +156,8 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::N | |||
| if (attr->group != 1) { | |||
| if (!ParseGroupDeConvolution(attr, op)) { | |||
| MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed"; | |||
| return RET_ERROR; | |||
| MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed, generalized group deconv hasn't support"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | |||
| @@ -522,6 +522,7 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr | |||
| dst_op->primitive->value.value = attr.release(); | |||
| return RET_OK; | |||
| } | |||
| schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) { | |||
| TensorCache tensor_cache; | |||
| // dst_graph->name = onnx_graph.name(); // this is not used | |||
| @@ -593,6 +594,7 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra | |||
| SetAllTensors(tensor_cache, dst_graph.get()); | |||
| return dst_graph.release(); | |||
| } | |||
| schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, | |||
| const QuantType &quantType) { | |||
| int status = ValidateFileStr(modelFile, ".onnx"); | |||
| @@ -72,11 +72,12 @@ STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniqu | |||
| auto data_index = tflite_op->inputs[0]; | |||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||
| std::vector<int> params; | |||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | |||
| ¶ms) != RET_OK) { | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else { | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| @@ -73,11 +73,12 @@ STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| auto data_index = tflite_op->inputs[2]; | |||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||
| std::vector<int> params; | |||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | |||
| ¶ms) != RET_OK) { | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else { | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| @@ -79,11 +79,12 @@ STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| // calculate pad params | |||
| std::vector<int> params; | |||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | |||
| ¶ms) != RET_OK) { | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else { | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| @@ -18,6 +18,7 @@ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <set> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/storage.h" | |||
| #include "flatbuffers/flatbuffers.h" | |||
| @@ -102,11 +103,6 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||
| auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; | |||
| auto op_type = GetMSOpType(tflite_op_type); | |||
| if (op_type == "CUSTOM") { | |||
| auto custom_type = (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code; | |||
| MS_LOG(ERROR) << "CUSTOM op is not supported, the type is " << custom_type; | |||
| return RET_ERROR; | |||
| } | |||
| auto op = std::make_unique<schema::CNodeT>(); | |||
| op->name = op_type + "-" + std::to_string(idx++); | |||
| @@ -122,7 +118,9 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| if (status == RET_OK) { | |||
| status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); | |||
| if (status != RET_OK) { | |||
| if (status == RET_NOT_SUPPORT) { | |||
| if (status == RET_NOT_FIND_OP) { | |||
| op_type = | |||
| (op_type != "Custom" ? op_type : (tflite_model->operator_codes[tflite_op->opcode_index])->custom_code); | |||
| NoSupportOp::GetInstance()->InsertOp(op_type); | |||
| } else { | |||
| MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | |||
| @@ -141,6 +139,16 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||
| STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | |||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||
| schema::MetaGraphT *sub_graph) { | |||
| std::set<int> output_index; | |||
| for (const auto &tflite_op : tflite_subgraph->operators) { | |||
| for (size_t j = 0; j < tflite_op->outputs.size(); ++j) { | |||
| int idx = tflite_op->outputs[j]; | |||
| if (idx < 0) { | |||
| idx += tflite_subgraph->tensors.size(); | |||
| } | |||
| output_index.insert(idx); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { | |||
| auto idx = tensorsInfo.tensorsId[i]; | |||
| if (idx < 0) { | |||
| @@ -173,11 +181,16 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> | |||
| return status; | |||
| } | |||
| } | |||
| // set tensor attr | |||
| if (isInput || isConst) { | |||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| } else { | |||
| tensor->nodeType = schema::NodeType_Parameter; | |||
| if (output_index.find(idx) == output_index.end() && tflite_tensor->shape[0] == 0) { | |||
| tensor->nodeType = schema::NodeType::NodeType_ValueNode; | |||
| } else { | |||
| tensor->nodeType = schema::NodeType_Parameter; | |||
| } | |||
| } | |||
| // quant param | |||
| @@ -246,7 +259,6 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) | |||
| if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| auto attr = op->primitive->value.AsDepthwiseConv2D(); | |||
| if (attr->channelMultiplier > 1) { | |||
| std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>(); | |||
| // get channel attr | |||
| if (op->inputIndex.empty()) { | |||
| MS_LOG(ERROR) << "the input of DepthwiseConv2D is null"; | |||
| @@ -263,7 +275,11 @@ STATUS TfliteModelParser::ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph) | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data_shape = data_tensor->dims; | |||
| if (data_shape.empty()) { | |||
| MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain only when running"; | |||
| return RET_NO_CHANGE; | |||
| } | |||
| std::unique_ptr<schema::Conv2DT> conv_attr = std::make_unique<schema::Conv2DT>(); | |||
| if (data_shape[3] == 1) { | |||
| conv_attr->channelIn = data_shape[3]; | |||
| conv_attr->channelOut = conv_attr->channelIn * attr->channelMultiplier; | |||
| @@ -372,7 +388,7 @@ schema::MetaGraphT *TfliteModelParser::ParseToFb(const std::string &model_file, | |||
| // update for depthwiseConv | |||
| status = ConvertGroupDepthwiseOp(meta_graph.get()); | |||
| if (status != RET_OK) { | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "convert group depthwise conv failed"; | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); | |||
| return nullptr; | |||
| @@ -86,6 +86,10 @@ class TfliteNodeParser { | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto data_ptr = buf_data->data.data(); | |||
| if (data_ptr == nullptr) { | |||
| MS_LOG(DEBUG) << "data is not a constant"; | |||
| return RET_NO_CHANGE; | |||
| } | |||
| switch (tflite_tensors[tensor_index]->type) { | |||
| case tflite::TensorType_UINT8: { | |||
| for (int i = 0; i < count; i++) { | |||
| @@ -71,11 +71,11 @@ STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | |||
| return RET_INVALID_OP_ATTR; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; | |||
| return RET_NOT_SUPPORT; | |||
| return RET_NOT_FIND_OP; | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Pad; | |||
| @@ -71,11 +71,12 @@ STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::un | |||
| auto data_index = tflite_op->inputs[0]; | |||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||
| std::vector<int> params; | |||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, | |||
| ¶ms) != RET_OK) { | |||
| int status = | |||
| getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, ¶ms); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "get padding params failed"; | |||
| return RET_ERROR; | |||
| } else { | |||
| } else if (status == RET_OK) { | |||
| attr->padUp = params.at(0); | |||
| attr->padDown = params.at(1); | |||
| attr->padLeft = params.at(2); | |||
| @@ -41,15 +41,31 @@ STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||
| } | |||
| attr->dType = 0; | |||
| // attr->start | |||
| // attr->limit | |||
| // attr->delta | |||
| std::vector<int> limit; | |||
| std::vector<int> delta; | |||
| int status = GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, limit); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "range -> limit get failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, delta); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| if (status == RET_OK) { | |||
| attr->limit = limit.front(); | |||
| attr->delta = delta.front(); | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Range; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||
| schema::Format::Format_NHWC); | |||
| int input_num = status == RET_OK ? 1 : 3; | |||
| for (int i = 0; i < input_num; ++i) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||
| schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||
| schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| @@ -69,7 +69,7 @@ STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||
| } else if (std::strcmp(node_name, "ReduceAny") == 0) { | |||
| // attr->mode; | |||
| MS_LOG(ERROR) << "ms-lite haven't supported REDUCE_ANY now"; | |||
| return RET_NOT_FIND_OP; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { | |||
| @@ -67,14 +67,15 @@ STATUS TfliteSplitParser::Parse(TfliteTensorsInfo *tensors_info, const std::uniq | |||
| return RET_ERROR; | |||
| } | |||
| attr->splitDim = axis; | |||
| if (tensor_shape[axis] % num_splits != 0) { | |||
| if (tensor_shape[axis] % num_splits != 0 && tensor_shape[axis] / num_splits != 0) { | |||
| MS_LOG(ERROR) << "num_splits can't divide tensor's length at axis " << axis; | |||
| return RET_ERROR; | |||
| } | |||
| attr->numberSplit = num_splits; | |||
| for (int i = 0; i < num_splits; i++) { | |||
| attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); | |||
| if (tensor_shape[axis] / num_splits != 0) { | |||
| for (int i = 0; i < num_splits; i++) { | |||
| attr->sizeSplits.push_back(tensor_shape[axis] / num_splits); | |||
| } | |||
| } | |||
| op->primitive->value.type = schema::PrimitiveType_Split; | |||
| @@ -52,17 +52,24 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| attr->newAxisMask = tflite_attr->new_axis_mask; | |||
| attr->shrinkAxisMask = tflite_attr->shrink_axis_mask; | |||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { | |||
| int status = | |||
| GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "stridedSlice -> begin get failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end)) { | |||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride)) { | |||
| MS_LOG(ERROR) << "stridedSlice -> stride get failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| status = GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->end); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "stridedSlice -> end get failed"; | |||
| return RET_ERROR; | |||
| } else if (status == RET_OK) { | |||
| status = | |||
| GetTfliteData(tflite_op->inputs[3], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->stride); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "stridedSlice -> stride get failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } | |||
| attr->isScale.assign(tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.begin(), | |||
| tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]->shape.end()); | |||
| @@ -70,8 +77,11 @@ STATUS TfliteStridedSliceParser::Parse(TfliteTensorsInfo *tensors_info, | |||
| op->primitive->value.type = schema::PrimitiveType_StridedSlice; | |||
| op->primitive->value.value = attr.release(); | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||
| schema::Format::Format_NHWC); | |||
| int input_num = status == RET_OK ? 1 : 4; | |||
| for (int i = 0; i < input_num; ++i) { | |||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||
| schema::Format::Format_NHWC); | |||
| } | |||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||
| schema::Format::Format_NHWC); | |||
| return RET_OK; | |||
| @@ -198,7 +198,10 @@ STATUS getPaddingParam(const std::unique_ptr<tflite::TensorT> &tensor, schema::P | |||
| MS_LOG(ERROR) << "the input tensor is null"; | |||
| return RET_ERROR; | |||
| } | |||
| if (tensor->shape.empty()) { | |||
| MS_LOG(DEBUG) << "the tensor's shape is dynamic, which obtain nly when running."; | |||
| return RET_NO_CHANGE; | |||
| } | |||
| int padUp = 0; | |||
| int padDown = 0; | |||
| int padLeft = 0; | |||