|
|
|
@@ -137,6 +137,49 @@ STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_at |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TfliteCustomParser::Rfft(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, |
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op, |
|
|
|
const std::unique_ptr<tflite::ModelT> &tflite_model) { |
|
|
|
std::unique_ptr<schema::RfftT> attr = std::make_unique<schema::RfftT>(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new op failed"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
std::vector<int> fft_length; |
|
|
|
if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, fft_length)) { |
|
|
|
MS_LOG(ERROR) << "rfft -> fftLength get failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
attr->fftLength = fft_length[0]; |
|
|
|
op->primitive->value.type = schema::PrimitiveType_Rfft; |
|
|
|
op->primitive->value.value = attr.release(); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TfliteCustomParser::FftReal(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, |
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) { |
|
|
|
std::unique_ptr<schema::FftRealT> attr = std::make_unique<schema::FftRealT>(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new op failed"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
op->primitive->value.type = schema::PrimitiveType_FftReal; |
|
|
|
op->primitive->value.value = attr.release(); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TfliteCustomParser::FftImag(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op, |
|
|
|
const std::unique_ptr<tflite::OperatorT> &tflite_op) { |
|
|
|
std::unique_ptr<schema::FftImagT> attr = std::make_unique<schema::FftImagT>(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "new op failed"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
op->primitive->value.type = schema::PrimitiveType_FftImag; |
|
|
|
op->primitive->value.value = attr.release(); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, |
|
|
|
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { |
|
|
|
MS_LOG(DEBUG) << "parse TfliteCustomParser"; |
|
|
|
@@ -163,6 +206,12 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni |
|
|
|
status = ExtractFeatures(custom_attr, op, tflite_op); |
|
|
|
} else if (custom_type == "AudioSpectrogram") { |
|
|
|
status = AudioSpectrogram(custom_attr, op, tflite_op); |
|
|
|
} else if (custom_type == "FlexRFFT") { |
|
|
|
status = Rfft(custom_attr, op, tflite_op, tflite_model); |
|
|
|
} else if (custom_type == "FlexReal") { |
|
|
|
status = FftReal(custom_attr, op, tflite_op); |
|
|
|
} else if (custom_type == "FlexImag") { |
|
|
|
status = FftImag(custom_attr, op, tflite_op); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "the custom op hasn't been supported now"; |
|
|
|
status = RET_NOT_FIND_OP; |
|
|
|
|