| @@ -78,6 +78,7 @@ enum TypeId : int { | |||||
| kNumberTypeFloat16, | kNumberTypeFloat16, | ||||
| kNumberTypeFloat32, | kNumberTypeFloat32, | ||||
| kNumberTypeFloat64, | kNumberTypeFloat64, | ||||
| kNumberTypeComplex64, | |||||
| kNumberTypeEnd | kNumberTypeEnd | ||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -212,6 +212,9 @@ union PrimitiveType { | |||||
| CustomExtractFeatures, | CustomExtractFeatures, | ||||
| AudioSpectrogram, | AudioSpectrogram, | ||||
| Mfcc, | Mfcc, | ||||
| Rfft, | |||||
| FftReal, | |||||
| FftImag, | |||||
| } | } | ||||
| enum QuantType: int { | enum QuantType: int { | ||||
| @@ -987,4 +987,14 @@ table Mfcc { | |||||
| freqLowerLimit : float; | freqLowerLimit : float; | ||||
| filterBankChannelNum : int; | filterBankChannelNum : int; | ||||
| dctCoeffNum : int; | dctCoeffNum : int; | ||||
| } | |||||
| } | |||||
| table Rfft { | |||||
| fftLength : int; | |||||
| } | |||||
| table FftReal { | |||||
| } | |||||
| table FftImag { | |||||
| } | |||||
| @@ -110,7 +110,8 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect | |||||
| for (const auto &kernel : input_kernels) { | for (const auto &kernel : input_kernels) { | ||||
| for (const auto &tensor : kernel->in_tensors()) { | for (const auto &tensor : kernel->in_tensors()) { | ||||
| auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); | auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); | ||||
| if (iter == all_output_tensors.end() && tensor->data_c() == nullptr) { | |||||
| if (iter == all_output_tensors.end() && | |||||
| !(tensor->category() == mindspore::lite::Tensor::CONST && tensor->data_c() != nullptr)) { | |||||
| input_tensors.emplace_back(tensor); | input_tensors.emplace_back(tensor); | ||||
| } | } | ||||
| } | } | ||||
| @@ -171,16 +171,16 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> & | |||||
| auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter); | ||||
| int kernel_h = conv_param->kernel_h_; | int kernel_h = conv_param->kernel_h_; | ||||
| int kernel_w = conv_param->kernel_w_; | int kernel_w = conv_param->kernel_w_; | ||||
| conv_param->input_h_ = inputs.front()->Height(); | |||||
| conv_param->input_w_ = inputs.front()->Width(); | |||||
| conv_param->input_channel_ = inputs.front()->Channel(); | |||||
| conv_param->output_h_ = outputs.front()->Height(); | |||||
| conv_param->output_w_ = outputs.front()->Width(); | |||||
| conv_param->output_channel_ = outputs.front()->Channel(); | |||||
| conv_param->op_parameter_.thread_num_ = ctx->thread_num_; | |||||
| bool use_winograd = false; | bool use_winograd = false; | ||||
| int out_unit; | int out_unit; | ||||
| if (primitive != nullptr && primitive->GetInferFlag()) { | if (primitive != nullptr && primitive->GetInferFlag()) { | ||||
| conv_param->input_h_ = inputs.front()->Height(); | |||||
| conv_param->input_w_ = inputs.front()->Width(); | |||||
| conv_param->input_channel_ = inputs.front()->Channel(); | |||||
| conv_param->output_h_ = outputs.front()->Height(); | |||||
| conv_param->output_w_ = outputs.front()->Width(); | |||||
| conv_param->output_channel_ = outputs.front()->Channel(); | |||||
| conv_param->op_parameter_.thread_num_ = ctx->thread_num_; | |||||
| CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); | CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); | ||||
| } | } | ||||
| @@ -137,6 +137,49 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_at | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model) { | |||||
| std::unique_ptr<schema::RfftT> attr = std::make_unique<schema::RfftT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| std::vector<int> fft_length; | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) { | |||||
| MS_LOG(ERROR) << "rfft -> fftLength get failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->fftLength = fft_length[0]; | |||||
| op->primitive->value.type = schema::PrimitiveType_Rfft; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteCustomParser::FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op) { | |||||
| std::unique_ptr<schema::FftRealT> attr = std::make_unique<schema::FftRealT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_FftReal; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op) { | |||||
| std::unique_ptr<schema::FftImagT> attr = std::make_unique<schema::FftImagT>(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new op failed"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_FftImag; | |||||
| op->primitive->value.value = attr.release(); | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | ||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | ||||
| MS_LOG(DEBUG) << "parse TfliteCustomParser"; | MS_LOG(DEBUG) << "parse TfliteCustomParser"; | ||||
| @@ -163,6 +206,12 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni | |||||
| status = ExtractFeatures(custom_attr, op, tflite_op); | status = ExtractFeatures(custom_attr, op, tflite_op); | ||||
| } else if (custom_type == "AudioSpectrogram") { | } else if (custom_type == "AudioSpectrogram") { | ||||
| status = AudioSpectrogram(custom_attr, op, tflite_op); | status = AudioSpectrogram(custom_attr, op, tflite_op); | ||||
| } else if (custom_type == "FlexRFFT") { | |||||
| status = Rfft(custom_attr, op, tflite_op, tflite_model); | |||||
| } else if (custom_type == "FlexReal") { | |||||
| status = FftReal(custom_attr, op, tflite_op); | |||||
| } else if (custom_type == "FlexImag") { | |||||
| status = FftImag(custom_attr, op, tflite_op); | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "the custom op hasn't been supported now"; | MS_LOG(ERROR) << "the custom op hasn't been supported now"; | ||||
| status = RET_NOT_FIND_OP; | status = RET_NOT_FIND_OP; | ||||
| @@ -49,6 +49,15 @@ class TfliteCustomParser : public TfliteNodeParser { | |||||
| STATUS ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | STATUS ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | ||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | const std::unique_ptr<tflite::OperatorT> &tflite_op); | ||||
| STATUS Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, const std::unique_ptr<tflite::ModelT> &tflite_model); | |||||
| STATUS FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | |||||
| STATUS FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op); | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -133,12 +133,12 @@ std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationF | |||||
| }; | }; | ||||
| std::map<int, TypeId> type_map = { | std::map<int, TypeId> type_map = { | ||||
| {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, | |||||
| {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, | |||||
| {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, | |||||
| {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, | |||||
| {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString}, | |||||
| }; | |||||
| {tflite::TensorType_FLOAT64, TypeId::kNumberTypeFloat64}, {tflite::TensorType_FLOAT32, TypeId::kNumberTypeFloat32}, | |||||
| {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32}, | |||||
| {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8}, | |||||
| {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8}, | |||||
| {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString}, | |||||
| {tflite::TensorType_COMPLEX64, TypeId::kNumberTypeComplex64}}; | |||||
| schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { | schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) { | ||||
| return tfMsActivationFunctionMap.at(tfliteAFType); | return tfMsActivationFunctionMap.at(tfliteAFType); | ||||