@@ -207,6 +207,11 @@ union PrimitiveType {
     LshProjection,
     HashtableLookup,
     SkipGram,
+    CustomPredict,
+    CustomNormalize,
+    CustomExtractFeatures,
+    AudioSpectrogram,
+    Mfcc,
 }

 enum QuantType: int {
@@ -963,3 +963,27 @@ table SkipGram {
     maxSkipSize : int;
     ngramSize : int;
 }
+
+table CustomPredict {
+    outputNum : int;
+    weightThreshold : float;
+}
+
+table CustomNormalize {
+}
+
+table CustomExtractFeatures {
+}
+
+table AudioSpectrogram {
+    windowSize : int;
+    stride : int;
+    magSquare : bool;
+}
+
+table Mfcc {
+    freqUpperLimit : float;
+    freqLowerLimit : float;
+    filterBankChannelNum : int;
+    dctCoeffNum : int;
+}
@@ -172,9 +172,11 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
 }

 void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                        schema::MetaGraphT *graph, TensorCache *tensor_cache) {
+                                        schema::MetaGraphT *graph, TensorCache *tensor_cache,
+                                        const QuantType &quant_type) {
   std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
   dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
+  dst_op_1->quantType = quant_type;
   ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
   auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0);
   std::vector<string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)};
@@ -185,6 +187,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons

   std::unique_ptr<schema::CNodeT> dst_op_2 = std::make_unique<schema::CNodeT>();
   dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0);
+  dst_op_2->quantType = quant_type;
   ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get());
   std::vector<string> biasadd_inputs{matmul_output_id, onnx_node.input(2)};
   std::vector<string> biasadd_outputs{onnx_node.output(0)};
@@ -343,8 +346,6 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
   }
   if (findQuantParams == needQuantParams) {
     dst_op->quantType = schema::QuantType_AwareTraining;
-  } else {
-    dst_op->quantType = schema::QuantType_QUANT_NONE;
   }
 }
@@ -520,7 +521,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
     }
     if (onnx_node.op_type() == "Gemm") {
       if (status == RET_OK) {
-        ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache);
+        ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache, quantType);
       }
       continue;
     } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
@@ -65,7 +65,7 @@ class OnnxModelParser : public ModelParser {
                          const QuantType &quantType);

   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                         schema::MetaGraphT *graph, TensorCache *tensor_cache);
+                         schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);

   STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
@@ -23,26 +23,14 @@
 namespace mindspore {
 namespace lite {
-STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
-                                 const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
-  MS_LOG(DEBUG) << "parse TfliteCustomParser";
-  if (op == nullptr) {
-    MS_LOG(ERROR) << "op is null";
-    return RET_NULL_PTR;
-  }
-  op->primitive = std::make_unique<schema::PrimitiveT>();
-  if (op->primitive == nullptr) {
-    MS_LOG(ERROR) << "op->primitive is null";
-    return RET_NULL_PTR;
-  }
-
+STATUS TfliteCustomParser::DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                             const std::unique_ptr<tflite::OperatorT> &tflite_op) {
   std::unique_ptr<schema::DetectionPostProcessT> attr = std::make_unique<schema::DetectionPostProcessT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
   }
-  const auto &custom_attr = tflite_op->custom_options;
   auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
   attr->format = schema::Format::Format_NHWC;
   attr->inputSize = tflite_op->inputs.size();
@@ -73,7 +61,115 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
   op->primitive->value.type = schema::PrimitiveType_DetectionPostProcess;
   op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::AudioSpectrogram(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                            const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::AudioSpectrogramT> attr = std::make_unique<schema::AudioSpectrogramT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
+  attr->windowSize = attr_map["window_size"].AsInt64();
+  attr->stride = attr_map["stride"].AsInt64();
+  attr->magSquare = attr_map["magnitude_squared"].AsBool();
+  op->primitive->value.type = schema::PrimitiveType_AudioSpectrogram;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Mfcc(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::MfccT> attr = std::make_unique<schema::MfccT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
+  attr->freqUpperLimit = attr_map["upper_frequency_limit"].AsInt64();
+  attr->freqLowerLimit = attr_map["lower_frequency_limit"].AsInt64();
+  attr->filterBankChannelNum = attr_map["filterbank_channel_count"].AsInt64();
+  attr->dctCoeffNum = attr_map["dct_coefficient_count"].AsInt64();
+  op->primitive->value.type = schema::PrimitiveType_Mfcc;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Predict(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                   const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::CustomPredictT> attr = std::make_unique<schema::CustomPredictT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  attr->outputNum = reinterpret_cast<const int *>(custom_attr.data())[0];
+  attr->weightThreshold = reinterpret_cast<const float *>(custom_attr.data())[1];
+  op->primitive->value.type = schema::PrimitiveType_CustomPredict;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Normalize(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                     const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::CustomNormalizeT> attr = std::make_unique<schema::CustomNormalizeT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_CustomNormalize;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                           const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::CustomExtractFeaturesT> attr = std::make_unique<schema::CustomExtractFeaturesT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_CustomExtractFeatures;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
+                                 const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "parse TfliteCustomParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+
+  const auto &custom_attr = tflite_op->custom_options;
+  const auto &opcode_index = tflite_op->opcode_index;
+  const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code;
+  int status = RET_OK;
+  if (custom_type == "TFLite_Detection_PostProcess") {
+    status = DetectPostProcess(custom_attr, op, tflite_op);
+  } else if (custom_type == "Predict") {
+    status = Predict(custom_attr, op, tflite_op);
+  } else if (custom_type == "Normalize") {
+    status = Normalize(custom_attr, op, tflite_op);
+  } else if (custom_type == "ExtractFeatures") {
+    status = ExtractFeatures(custom_attr, op, tflite_op);
+  } else if (custom_type == "AudioSpectrogram") {
+    status = AudioSpectrogram(custom_attr, op, tflite_op);
+  } else {
+    MS_LOG(ERROR) << "the custom op hasn't been supported now";
+    status = RET_NOT_FIND_OP;
+  }
+  if (status != RET_OK) {
+    return status;
+  }
+
   for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
     AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
                schema::Format::Format_NHWC);
@@ -82,7 +178,7 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
     AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
                 schema::Format::Format_NHWC);
   }
-  return RET_OK;
+  return status;
 }

 TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());
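Note on the Predict branch in the hunk above: unlike DetectPostProcess, AudioSpectrogram and Mfcc, which decode custom_options as a FlexBuffer map, Predict() reads the buffer as a raw byte blob, a 32-bit output count followed by a 32-bit float weight threshold. A minimal sketch of that assumed layout, using a hypothetical helper name and memcpy rather than the reinterpret_cast indexing used in the patch:

#include <cstdint>
#include <cstring>
#include <vector>

struct PredictOptions {
  int32_t output_num;      // first 4 bytes of custom_options
  float weight_threshold;  // next 4 bytes of custom_options
};

// Hypothetical helper (not part of the patch): unpack the 8-byte blob.
inline PredictOptions UnpackPredictOptions(const std::vector<uint8_t> &custom_attr) {
  PredictOptions opts{};
  if (custom_attr.size() >= sizeof(int32_t) + sizeof(float)) {
    std::memcpy(&opts.output_num, custom_attr.data(), sizeof(int32_t));
    std::memcpy(&opts.weight_threshold, custom_attr.data() + sizeof(int32_t), sizeof(float));
  }
  return opts;
}

The memcpy form also side-steps the alignment assumption that indexing a reinterpret_cast'ed pointer makes on custom_attr.data().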
@@ -31,6 +31,24 @@ class TfliteCustomParser : public TfliteNodeParser {
   STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
               const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
+
+  STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                           const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS AudioSpectrogram(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                          const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS Mfcc(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+              const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS Predict(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                 const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS Normalize(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                   const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                         const std::unique_ptr<tflite::OperatorT> &tflite_op);
 };
 } // namespace lite
 } // namespace mindspore
@@ -47,14 +47,11 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
 STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
                                               const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
-  auto count = 1;
-  std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; });
-  auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType));
   auto buffer_idx = tflite_tensor->buffer;
   if (!tflite_model_buffer[buffer_idx]->data.empty()) {
+    auto data_size = tflite_model_buffer[buffer_idx]->data.size();
     tensor->data.resize(data_size);
-    if (memcpy_s(tensor->data.data(), tensor->data.size(), tflite_model_buffer[buffer_idx]->data.data(),
-                 tflite_model_buffer[buffer_idx]->data.size())) {
+    if (memcpy_s(tensor->data.data(), data_size, tflite_model_buffer[buffer_idx]->data.data(), data_size) != EOK) {
       MS_LOG(ERROR) << "memcpy tensor data failed";
       return RET_MEMORY_FAILED;
     }
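The sizing change above swaps a shape-derived byte count for the buffer's own length. The two can disagree once tensors without a fixed element width pass through the converter (the same change set maps TensorType_STRING below), so the serialized buffer length is the safer source of truth. For comparison only, a rough restatement of the removed computation as a standalone, hypothetical helper:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the removed logic: size = product(shape) * element width.
// This miscounts whenever the flatbuffer's stored byte length is the real source of truth
// (e.g. packed string tensors), which is why the patch copies data.size() bytes instead.
inline size_t ShapeBasedByteSize(const std::vector<int32_t> &shape, size_t elem_size) {
  size_t count = 1;
  for (int32_t dim : shape) {
    count *= static_cast<size_t>(dim);
  }
  return count * elem_size;
}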
@@ -120,6 +120,9 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
   {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"},
   {tflite::BuiltinOperator_NEG, "Neg"},
   {tflite::BuiltinOperator_PRELU, "PRELU"},
+  {tflite::BuiltinOperator_HASHTABLE_LOOKUP, "HashtableLookup"},
+  {tflite::BuiltinOperator_LSH_PROJECTION, "LshProjection"},
+  {tflite::BuiltinOperator_SKIP_GRAM, "SKipGram"},
 };

 std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{
@@ -134,7 +137,7 @@ std::map<int, TypeId> type_map = {
   {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
   {tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
   {tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
-  {tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
+  {tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString},
 };

 schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) {
@@ -117,7 +117,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
     }
   } else {
     if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
-      MS_LOG(ERROR) << "memset_s conv_bias_data failed";
+      MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
       delete[] add_bias_data;
       return lite::RET_MEMORY_FAILED;
     }