| @@ -57,7 +57,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } | } | ||||
| if (!x_data->data.empty()) { | if (!x_data->data.empty()) { | ||||
| std::vector<tflite::TensorT *> x_tensors{x_tensor.get()}; | std::vector<tflite::TensorT *> x_tensors{x_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(x_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse the first tensor failed"; | MS_LOG(ERROR) << "parse the first tensor failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -76,7 +76,7 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } | } | ||||
| if (!y_data->data.empty()) { | if (!y_data->data.empty()) { | ||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | MS_LOG(ERROR) << "parse the second tensor failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -59,7 +59,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -79,7 +79,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfliteO | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -59,7 +59,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto weight_shape = weight_tensor->shape; | auto weight_shape = weight_tensor->shape; | ||||
| @@ -123,7 +123,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -133,7 +133,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| auto bias_index = tflite_op->inputs[2]; | auto bias_index = tflite_op->inputs[2]; | ||||
| const auto &bias_tensor = tflite_tensors[bias_index]; | const auto &bias_tensor = tflite_tensors[bias_index]; | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -44,7 +44,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -58,7 +58,7 @@ STATUS TfliteFakeQuantParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -45,7 +45,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, true)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -59,7 +59,7 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | std::vector<tflite::TensorT *> bias_tensors{bias_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(bias_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse bias failed"; | MS_LOG(ERROR) << "parse bias failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -58,7 +58,7 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl | |||||
| } | } | ||||
| if (!y_data->data.empty()) { | if (!y_data->data.empty()) { | ||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | MS_LOG(ERROR) << "parse the second tensor failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -49,6 +49,25 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->batchDims = 0; | attr->batchDims = 0; | ||||
| auto y_index = tfliteOp->inputs[1]; | |||||
| const auto &y_tensor = tfliteTensors[y_index]; | |||||
| if (y_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| auto &y_data = tfliteModelBuffer.at(y_tensor->buffer); | |||||
| if (y_data == nullptr) { | |||||
| MS_LOG(ERROR) << "the data of the second input is null"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| if (!y_data->data.empty()) { | |||||
| std::vector<tflite::TensorT *> y_tensors{y_tensor.get()}; | |||||
| if (RET_OK != ParseTensor(y_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse the second tensor failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| op->primitive->value.type = schema::PrimitiveType_Gather; | op->primitive->value.type = schema::PrimitiveType_Gather; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| return RET_OK; | return RET_OK; | ||||
| @@ -45,7 +45,8 @@ STATUS TfliteNodeParser::CopyTfliteTensorData(const std::vector<std::unique_ptr< | |||||
| STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts, | STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| mindspore::lite::TensorCache *tensor_cache, | mindspore::lite::TensorCache *tensor_cache, | ||||
| int node_type) { | |||||
| int node_type, | |||||
| bool isWeight) { | |||||
| for (const auto &t : ts) { | for (const auto &t : ts) { | ||||
| auto idx = tensor_cache->FindTensor(t->name); | auto idx = tensor_cache->FindTensor(t->name); | ||||
| if (idx < 0) { | if (idx < 0) { | ||||
| @@ -53,6 +54,12 @@ STATUS TfliteNodeParser::ParseTensor(const std::vector<tflite::TensorT *> &ts, | |||||
| tensor->dataType = GetTfliteDataType(t->type); | tensor->dataType = GetTfliteDataType(t->type); | ||||
| tensor->dims = t->shape; | tensor->dims = t->shape; | ||||
| if (isWeight) { | |||||
| tensor->format = schema::Format_KHWC; | |||||
| } else { | |||||
| tensor->format = schema::Format_NHWC; | |||||
| } | |||||
| if (t->buffer > 0) { | if (t->buffer > 0) { | ||||
| CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); | CopyTfliteTensorData(tfliteModelBuffer, t, tensor.get()); | ||||
| } | } | ||||
| @@ -47,7 +47,8 @@ class TfliteNodeParser { | |||||
| STATUS ParseTensor(const std::vector<tflite::TensorT *> &ts, | STATUS ParseTensor(const std::vector<tflite::TensorT *> &ts, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| mindspore::lite::TensorCache *tensor_cache, | mindspore::lite::TensorCache *tensor_cache, | ||||
| int node_type); | |||||
| int node_type, | |||||
| bool isWeight); | |||||
| STATUS CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | STATUS CopyTfliteTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tfliteModelBuffer, | ||||
| const tflite::TensorT *tflite_tensor, | const tflite::TensorT *tflite_tensor, | ||||
| @@ -50,7 +50,7 @@ STATUS TfliteTransposeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | std::vector<tflite::TensorT *> weight_tensors{weight_tensor.get()}; | ||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST)) { | |||||
| if (RET_OK != ParseTensor(weight_tensors, tfliteModelBuffer, tensor_cache, TF_CONST, false)) { | |||||
| MS_LOG(ERROR) << "parse weight failed"; | MS_LOG(ERROR) << "parse weight failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||