| @@ -20,7 +20,6 @@ | |||||
| #include <stdlib.h> | #include <stdlib.h> | ||||
| #include <string.h> | #include <string.h> | ||||
| #include <stddef.h> | #include <stddef.h> | ||||
| #include <initializer_list> | |||||
| #define DEFAULT_CAPACITY 4 | #define DEFAULT_CAPACITY 4 | ||||
| struct MSTensor; | struct MSTensor; | ||||
| @@ -31,7 +31,7 @@ template <typename T> | |||||
| Vector<T>::Vector(size_t size) { | Vector<T>::Vector(size_t size) { | ||||
| size_ = size; | size_ = size; | ||||
| elem_size_ = sizeof(T); | elem_size_ = sizeof(T); | ||||
| capacity_ = size; | |||||
| capacity_ = (size == 0 ? DEFAULT_CAPACITY : size); | |||||
| data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_)); | data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_)); | ||||
| if (data_ == nullptr) { | if (data_ == nullptr) { | ||||
| MS_C_EXCEPTION("malloc data failed"); | MS_C_EXCEPTION("malloc data failed"); | ||||
| @@ -43,7 +43,7 @@ template <typename T> | |||||
| Vector<T>::Vector(size_t size, const T &value) { | Vector<T>::Vector(size_t size, const T &value) { | ||||
| size_ = size; | size_ = size; | ||||
| elem_size_ = sizeof(T); | elem_size_ = sizeof(T); | ||||
| capacity_ = size; | |||||
| capacity_ = (size == 0 ? DEFAULT_CAPACITY : size); | |||||
| data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_)); | data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_)); | ||||
| if (data_ == nullptr) { | if (data_ == nullptr) { | ||||
| MS_C_EXCEPTION("malloc data failed"); | MS_C_EXCEPTION("malloc data failed"); | ||||
| @@ -115,7 +115,7 @@ void Vector<T>::push_back(const T &elem) { | |||||
| template <typename T> | template <typename T> | ||||
| void Vector<T>::push_back(T &&elem) { | void Vector<T>::push_back(T &&elem) { | ||||
| if (data_ == nullptr) { | if (data_ == nullptr) { | ||||
| data_ = reinterpret_cast<T *>(malloc(elem_size_)); | |||||
| data_ = reinterpret_cast<T *>(malloc(capacity_ * elem_size_)); | |||||
| if (data_ == nullptr) { | if (data_ == nullptr) { | ||||
| MS_C_EXCEPTION("malloc data failed"); | MS_C_EXCEPTION("malloc data failed"); | ||||
| } | } | ||||
| @@ -102,9 +102,13 @@ int WriteToBin(const std::string &file_path, void *data, size_t size) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| int CompareOutputData(float *output_data, float *correct_data, int data_size) { | |||||
| int CompareOutputData(float *output_data, size_t output_size, float *correct_data, size_t data_size) { | |||||
| if (output_size != data_size) { | |||||
| printf("compare failed, output_size %zu isn't equal to data_size %zu.\n", output_size, data_size); | |||||
| return 0; | |||||
| } | |||||
| float error = 0; | float error = 0; | ||||
| for (int i = 0; i < data_size; i++) { | |||||
| for (size_t i = 0; i < data_size; i++) { | |||||
| float abs = fabs(output_data[i] - correct_data[i]); | float abs = fabs(output_data[i] - correct_data[i]); | ||||
| if (abs > 0.00001) { | if (abs > 0.00001) { | ||||
| error += abs; | error += abs; | ||||
| @@ -120,12 +124,12 @@ int CompareOutputData(float *output_data, float *correct_data, int data_size) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| int CompareOutput(float *output_data, std::string file_path) { | |||||
| size_t output_size; | |||||
| auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size)); | |||||
| size_t output_num = output_size / sizeof(float); | |||||
| printf("output num : %zu\n", output_num); | |||||
| int res = CompareOutputData(output_data, ground_truth, output_num); | |||||
| int CompareOutput(float *output_data, size_t output_num, std::string file_path) { | |||||
| size_t ground_truth_size; | |||||
| auto ground_truth = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &ground_truth_size)); | |||||
| size_t ground_truth_num = ground_truth_size / sizeof(float); | |||||
| printf("ground truth num : %zu\n", ground_truth_num); | |||||
| int res = CompareOutputData(output_data, output_num, ground_truth, ground_truth_num); | |||||
| delete[] ground_truth; | delete[] ground_truth; | ||||
| return res; | return res; | ||||
| } | } | ||||
| @@ -50,8 +50,8 @@ void WriteToTxt(const std::string &file_path, void *data, size_t element_size) { | |||||
| int WriteToBin(const std::string &file_path, void *data, size_t size); | int WriteToBin(const std::string &file_path, void *data, size_t size); | ||||
| int CompareOutputData(float *output_data, float *correct_data, int data_size); | |||||
| int CompareOutput(float *output_data, std::string file_path); | |||||
| int CompareOutputData(float *output_data, size_t output_num, float *correct_data, size_t data_size); | |||||
| int CompareOutput(float *output_data, size_t output_num, std::string file_path); | |||||
| std::string GetAndroidPackageName(); | std::string GetAndroidPackageName(); | ||||
| std::string GetAndroidPackagePath(); | std::string GetAndroidPackagePath(); | ||||
| @@ -169,6 +169,7 @@ class PrimitiveC { | |||||
| } | } | ||||
| auto ret = primc->UnPackSchemaPrimitive(primitive); | auto ret = primc->UnPackSchemaPrimitive(primitive); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| delete primc; | |||||
| MS_LOG(ERROR) << "UnPackSchemaPrimitive failed"; | MS_LOG(ERROR) << "UnPackSchemaPrimitive failed"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -144,6 +144,8 @@ void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<in | |||||
| for (int i = 0; i < shape_size; i++) { | for (int i = 0; i < shape_size; i++) { | ||||
| if (static_cast<int>(data[i]) == -1) { | if (static_cast<int>(data[i]) == -1) { | ||||
| index = i; | index = i; | ||||
| } else if (static_cast<int>(data[i]) == 0) { | |||||
| size *= inputs[0]->shape()[i]; | |||||
| } else { | } else { | ||||
| size *= data[i]; | size *= data[i]; | ||||
| } | } | ||||
| @@ -64,6 +64,10 @@ int StridedSlice::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr | |||||
| } | } | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| auto attr = new (std::nothrow) schema::StridedSliceT(); | auto attr = new (std::nothrow) schema::StridedSliceT(); | ||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new StridedSlice failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask")); | attr->beginMask = GetValue<int>(prim.GetAttr("begin_mask")); | ||||
| attr->endMask = GetValue<int>(prim.GetAttr("end_mask")); | attr->endMask = GetValue<int>(prim.GetAttr("end_mask")); | ||||
| attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask")); | attr->ellipsisMask = GetValue<int>(prim.GetAttr("ellipsis_mask")); | ||||
| @@ -43,6 +43,10 @@ int Transpose::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> & | |||||
| } | } | ||||
| if (this->primitive_->value.value == nullptr) { | if (this->primitive_->value.value == nullptr) { | ||||
| auto attr = new (std::nothrow) schema::TransposeT(); | auto attr = new (std::nothrow) schema::TransposeT(); | ||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "new TransposeT failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| MS_ASSERT(inputs.size() == kAnfPopulaterTwo); | MS_ASSERT(inputs.size() == kAnfPopulaterTwo); | ||||
| auto inputNode = inputs[kAnfPopulaterOne]; | auto inputNode = inputs[kAnfPopulaterOne]; | ||||
| if (inputNode->isa<ValueNode>()) { | if (inputNode->isa<ValueNode>()) { | ||||
| @@ -54,7 +54,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack1) { | |||||
| float out[20] = {0}; | float out[20] = {0}; | ||||
| Conv1x1InputPack(in, out, conv_param, sizeof(float)); | Conv1x1InputPack(in, out, conv_param, sizeof(float)); | ||||
| EXPECT_EQ(0, lite::CompareOutputData(out, correct, 20)); | |||||
| EXPECT_EQ(0, lite::CompareOutputData(out, 20, correct, 20)); | |||||
| delete conv_param; | delete conv_param; | ||||
| } | } | ||||
| @@ -114,7 +114,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack3) { | |||||
| -5.052577, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; | -5.052577, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; | ||||
| Conv1x1InputPack(in, out, conv_param, sizeof(float)); | Conv1x1InputPack(in, out, conv_param, sizeof(float)); | ||||
| EXPECT_EQ(0, lite::CompareOutputData(out, correct, 18)); | |||||
| EXPECT_EQ(0, lite::CompareOutputData(out, 18, correct, 18)); | |||||
| delete conv_param; | delete conv_param; | ||||
| } | } | ||||
| @@ -136,7 +136,7 @@ TEST_F(TestConv1x1Fp32, Input1x1PrePack4) { | |||||
| 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; | 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; | ||||
| float out[54] = {0}; | float out[54] = {0}; | ||||
| Conv1x1InputPack(in, out, conv_param, sizeof(float)); | Conv1x1InputPack(in, out, conv_param, sizeof(float)); | ||||
| EXPECT_EQ(0, lite::CompareOutputData(out, correct, 54)); | |||||
| EXPECT_EQ(0, lite::CompareOutputData(out, 54, correct, 54)); | |||||
| delete conv_param; | delete conv_param; | ||||
| } | } | ||||
| @@ -166,7 +166,7 @@ TEST_F(TestConv1x1Fp32, Conv1x1WeightTest1) { | |||||
| conv_param->output_channel_ = 7; | conv_param->output_channel_ = 7; | ||||
| float out[96] = {0}; | float out[96] = {0}; | ||||
| Pack1x1WeightFp32(in, out, conv_param); | Pack1x1WeightFp32(in, out, conv_param); | ||||
| EXPECT_EQ(0, lite::CompareOutputData(out, co, 96)); | |||||
| EXPECT_EQ(0, lite::CompareOutputData(out, 96, co, 96)); | |||||
| delete conv_param; | delete conv_param; | ||||
| } | } | ||||
| @@ -75,7 +75,7 @@ TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack1) { | |||||
| 0.000, 0.000, 0.000, 0.00}; | 0.000, 0.000, 0.000, 0.00}; | ||||
| float dst[256] = {0}; | float dst[256] = {0}; | ||||
| PackDeConvWeightFp32(in, dst, 5, 6, 2 * 2); | PackDeConvWeightFp32(in, dst, 5, 6, 2 * 2); | ||||
| EXPECT_EQ(0, lite::CompareOutputData(dst, co, 256)); | |||||
| EXPECT_EQ(0, lite::CompareOutputData(dst, 256, co, 256)); | |||||
| } | } | ||||
| TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) { | TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) { | ||||
| @@ -90,7 +90,7 @@ TEST_F(TestDeConvolutionFp32, DeConvWeightC4x4Pack2) { | |||||
| -0.293, 18.686, 0.0873, 0, 0, 0, 0, 0, 0, 0, 0, 0}; | -0.293, 18.686, 0.0873, 0, 0, 0, 0, 0, 0, 0, 0, 0}; | ||||
| float dst[64] = {0}; | float dst[64] = {0}; | ||||
| PackDeConvWeightFp32(in, dst, 6, 3, 2 * 1); | PackDeConvWeightFp32(in, dst, 6, 3, 2 * 1); | ||||
| EXPECT_EQ(0, lite::CompareOutputData(dst, co, 64)); | |||||
| EXPECT_EQ(0, lite::CompareOutputData(dst, 64, co, 64)); | |||||
| } | } | ||||
| TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) { | TEST_F(TestDeConvolutionFp32, PostConvFuncC8Test1) { | ||||
| @@ -212,7 +212,7 @@ TEST_F(TestActGradFp32, SigmoidGradFp32) { | |||||
| int res = lite::CompareRelativeOutput(output_data, output_path); | int res = lite::CompareRelativeOutput(output_data, output_path); | ||||
| EXPECT_EQ(res, 0); | EXPECT_EQ(res, 0); | ||||
| // lite::CompareOutput(output_data, output_path); | |||||
| // lite::CompareOutput(output_data, output_data_size, output_path); | |||||
| delete[] input_data; | delete[] input_data; | ||||
| delete[] output_data; | delete[] output_data; | ||||
| @@ -58,7 +58,7 @@ TEST_F(TestBiasGradFp32, BiasGradFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin"; | std::string output_path = "./test_data/operators/biasgradfp32_1_db_7.bin"; | ||||
| lite::CompareOutput(output_data, output_path); | |||||
| lite::CompareOutput(output_data, 7, output_path); | |||||
| delete[] input_data; | delete[] input_data; | ||||
| delete[] output_data; | delete[] output_data; | ||||
| @@ -96,7 +96,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingGradFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; | std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; | ||||
| auto res = lite::CompareOutput(output_data, output_path); | |||||
| auto res = lite::CompareOutput(output_data, output_data_size, output_path); | |||||
| EXPECT_EQ(res, 0); | EXPECT_EQ(res, 0); | ||||
| delete[] input_data; | delete[] input_data; | ||||
| @@ -152,7 +152,7 @@ TEST_F(TestPoolingGradFp32, AvgPoolingKernelGradFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; | std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_1_28_28_3.bin"; | ||||
| auto res = lite::CompareOutput(output_data, output_path); | |||||
| auto res = lite::CompareOutput(output_data, output_data_size, output_path); | |||||
| EXPECT_EQ(res, 0); | EXPECT_EQ(res, 0); | ||||
| delete[] input_data; | delete[] input_data; | ||||
| @@ -213,7 +213,8 @@ TEST_F(TestPoolingGradFp32, AvgPoolingBatchGradFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin"; | std::string output_path = "./test_data/pooling/avgpoolgradfp32_1_dx_3_28_28_3.bin"; | ||||
| auto res = lite::CompareOutput(output_data, output_path); | |||||
| size_t output_data_size = dx_tensor.ElementsNum(); | |||||
| auto res = lite::CompareOutput(output_data, output_data_size, output_path); | |||||
| EXPECT_EQ(res, 0); | EXPECT_EQ(res, 0); | ||||
| delete[] input_data; | delete[] input_data; | ||||
| @@ -388,7 +389,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin"; | std::string output_path = "./test_data/pooling/maxpoolgradfp32_1_xgrad_1_28_28_3.bin"; | ||||
| auto res = lite::CompareOutput(output_data, output_path); | |||||
| auto res = lite::CompareOutput(output_data, output_data_size, output_path); | |||||
| EXPECT_EQ(res, 0); | EXPECT_EQ(res, 0); | ||||
| free(pooling_param); | free(pooling_param); | ||||
| @@ -70,7 +70,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||||
| printf("==================Testing Grad===============\n"); | printf("==================Testing Grad===============\n"); | ||||
| std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin"; | std::string output_path = "./test_data/operators/sce_fp32_1_loss_1.bin"; | ||||
| lite::CompareOutput(loss, output_path); | |||||
| lite::CompareOutput(loss, 1, output_path); | |||||
| ((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train(); | ((mindspore::kernel::SparseSoftmaxCrossEntropyWithLogitsCPUKernel *)kernel_obj)->train(); | ||||
| kernel_obj->Run(); | kernel_obj->Run(); | ||||
| @@ -81,7 +81,7 @@ TEST_F(TestSoftmaxCrossEntropyFp32, SoftmaxCrossEntropyFp32) { | |||||
| } | } | ||||
| std::cout << std::endl; | std::cout << std::endl; | ||||
| std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin"; | std::string grad_path = "./test_data/operators/sce_fp32_1_dy_6_4.bin"; | ||||
| lite::CompareOutput(grad, grad_path); | |||||
| lite::CompareOutput(grad, 24, grad_path); | |||||
| delete[] ll_labels; | delete[] ll_labels; | ||||
| delete[] labels; | delete[] labels; | ||||
| @@ -20,10 +20,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, | |||||
| const caffe::LayerParameter &weight, | |||||
| schema::CNodeT *op, | |||||
| std::vector<schema::TensorT *> *weightVec) { | |||||
| STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, const caffe::LayerParameter &weight, | |||||
| schema::CNodeT *op, std::vector<schema::TensorT *> *weightVec) { | |||||
| MS_LOG(DEBUG) << "parse CaffeReduceParser"; | MS_LOG(DEBUG) << "parse CaffeReduceParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -67,6 +65,11 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, | |||||
| } else { | } else { | ||||
| attr->axes = std::vector(1, 0); | attr->axes = std::vector(1, 0); | ||||
| } | } | ||||
| if (reduce_param.has_coeff()) { | |||||
| attr->coeff = reduce_param.coeff(); | |||||
| } else { | |||||
| attr->coeff = 1.0; | |||||
| } | |||||
| attr->reduceToEnd = true; | attr->reduceToEnd = true; | ||||
| attr->keepDims = false; | attr->keepDims = false; | ||||
| op->name = proto.name(); | op->name = proto.name(); | ||||
| @@ -78,4 +81,3 @@ STATUS CaffeReduceParser::Parse(const caffe::LayerParameter &proto, | |||||
| CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser()); | CaffeNodeRegistrar g_caffeReduceParser("Reduction", new CaffeReduceParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteActivationParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -74,10 +72,10 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| op->primitive->value.type = schema::PrimitiveType_Activation; | op->primitive->value.type = schema::PrimitiveType_Activation; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteActivationParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteActivationParser() : TfliteNodeParser("node_name") {} | TfliteActivationParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteReluParser : public TfliteActivationParser { | class TfliteReluParser : public TfliteActivationParser { | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteAddNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteAddNParser"; | MS_LOG(DEBUG) << "parse TfliteAddNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -44,16 +41,16 @@ STATUS TfliteAddNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->N = tflite_tensors.size() - 1; | |||||
| attr->N = tflite_model->subgraphs[0]->tensors.size() - 1; | |||||
| op->primitive->value.type = schema::PrimitiveType_AddN; | op->primitive->value.type = schema::PrimitiveType_AddN; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteAddNParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteAddNParser() : TfliteNodeParser("AddN") {} | TfliteAddNParser() : TfliteNodeParser("AddN") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteArgmaxParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | MS_LOG(DEBUG) << "parse TfliteArgmaxParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -50,8 +47,8 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| // get axis attr | // get axis attr | ||||
| auto axis_idx = tflite_op->inputs[1]; | auto axis_idx = tflite_op->inputs[1]; | ||||
| std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha) {}); | |||||
| auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; | |||||
| auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -66,10 +63,10 @@ STATUS TfliteArgmaxParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMax; | op->primitive->value.type = schema::PrimitiveType_ArgMax; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteArgmaxParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} | TfliteArgmaxParser() : TfliteNodeParser("Argmax") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteArgminParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteArgminParser"; | MS_LOG(DEBUG) << "parse TfliteArgminParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -50,8 +47,8 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| // get axis attr | // get axis attr | ||||
| auto axis_idx = tflite_op->inputs[1]; | auto axis_idx = tflite_op->inputs[1]; | ||||
| std::for_each(tflite_tensors[axis_idx]->shape.begin(), tflite_tensors[axis_idx]->shape.end(), [&](int32_t sha) {}); | |||||
| auto &buf_data = tflite_model_buffer[tflite_tensors[axis_idx]->buffer]; | |||||
| auto buffer_idx = tflite_model->subgraphs[0]->tensors[axis_idx]->buffer; | |||||
| auto &buf_data = tflite_model->buffers[buffer_idx]; | |||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "the buf data is null"; | MS_LOG(ERROR) << "the buf data is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -66,10 +63,10 @@ STATUS TfliteArgminParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_ArgMin; | op->primitive->value.type = schema::PrimitiveType_ArgMin; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteArgminParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteArgminParser() : TfliteNodeParser("Argmin") {} | TfliteArgminParser() : TfliteNodeParser("Argmin") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,12 +22,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteDoubleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -171,20 +168,17 @@ STATUS TfliteDoubleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| // set input | // set input | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteSingleInputOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -210,13 +204,13 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } else if (std::strcmp(node_name, "Exp") == 0) { | } else if (std::strcmp(node_name, "Exp") == 0) { | ||||
| MS_LOG(DEBUG) << "parse TfliteExpParser"; | MS_LOG(DEBUG) << "parse TfliteExpParser"; | ||||
| auto attr = std::make_unique<schema::ExpT>(); | auto attr = std::make_unique<schema::ExpT>(); | ||||
| attr->base = -1; // -1 represent base = e | |||||
| attr->scale = 1; | |||||
| attr->shift = 0; | |||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| MS_LOG(ERROR) << "new op failed"; | MS_LOG(ERROR) << "new op failed"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| attr->base = -1; // -1 represent base = e | |||||
| attr->scale = 1; | |||||
| attr->shift = 0; | |||||
| op->primitive->value.type = schema::PrimitiveType_Exp; | op->primitive->value.type = schema::PrimitiveType_Exp; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else if (std::strcmp(node_name, "Sqrt") == 0) { | } else if (std::strcmp(node_name, "Sqrt") == 0) { | ||||
| @@ -300,7 +294,7 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| } | } | ||||
| op->primitive->value.type = schema::PrimitiveType_Floor; | op->primitive->value.type = schema::PrimitiveType_Floor; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } else if (std::strcmp(node_name, "NEG") == 0) { | |||||
| } else if (std::strcmp(node_name, "Neg") == 0) { | |||||
| MS_LOG(DEBUG) << "parse TfliteNegParser"; | MS_LOG(DEBUG) << "parse TfliteNegParser"; | ||||
| auto attr = std::make_unique<schema::NegT>(); | auto attr = std::make_unique<schema::NegT>(); | ||||
| if (attr == nullptr) { | if (attr == nullptr) { | ||||
| @@ -311,18 +305,16 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } | } | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteCompareOpParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -393,11 +385,11 @@ STATUS TfliteCompareOpParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -424,7 +416,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); | |||||
| TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); | ||||
| TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); | ||||
| TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); | ||||
| TfliteNodeRegister g_tfliteNegParser("NEG", new TfliteNegParser()); | |||||
| TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser()); | |||||
| TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); | ||||
| TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); | ||||
| @@ -29,11 +29,8 @@ class TfliteDoubleInputOpParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | TfliteDoubleInputOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteAddParser : public TfliteDoubleInputOpParser { | class TfliteAddParser : public TfliteDoubleInputOpParser { | ||||
| @@ -95,11 +92,8 @@ class TfliteSingleInputOpParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | TfliteSingleInputOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteAbsParser : public TfliteSingleInputOpParser { | class TfliteAbsParser : public TfliteSingleInputOpParser { | ||||
| @@ -166,11 +160,8 @@ class TfliteCompareOpParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | TfliteCompareOpParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteEqualParser : public TfliteCompareOpParser { | class TfliteEqualParser : public TfliteCompareOpParser { | ||||
| @@ -23,12 +23,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteBatchToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -54,11 +51,12 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->blockShape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->blockShape)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | MS_LOG(ERROR) << "get batchToSpace -> blockShape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->crops)) { | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->crops)) { | |||||
| MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | MS_LOG(ERROR) << "get batchToSpace -> crops failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -66,10 +64,10 @@ STATUS TfliteBatchToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | op->primitive->value.type = schema::PrimitiveType_BatchToSpace; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteBatchToSpaceParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} | TfliteBatchToSpaceParser() : TfliteNodeParser("BatchToSpace") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | class TfliteBatchToSpaceNDParser : public TfliteBatchToSpaceParser { | ||||
| @@ -22,11 +22,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteBroadcastToParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | MS_LOG(DEBUG) << "parse TfliteBroadcastToParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -44,7 +42,8 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> & | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dst_shape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->dst_shape)) { | |||||
| MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | MS_LOG(ERROR) << "get broadCastTo -> dst_shape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -52,10 +51,10 @@ STATUS TfliteBroadcastToParser::Parse(const std::unique_ptr<tflite::OperatorT> & | |||||
| op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | op->primitive->value.type = schema::PrimitiveType_BroadcastTo; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteBroadcastToParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} | TfliteBroadcastToParser() : TfliteNodeParser("BroadcastTo") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteCastParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCastParser"; | MS_LOG(DEBUG) << "parse TfliteCastParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -43,7 +40,7 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; | |||||
| const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -52,7 +49,7 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| if (attr->srcT == TypeId::kNumberTypeBool) { | if (attr->srcT == TypeId::kNumberTypeBool) { | ||||
| attr->srcT = TypeId::kNumberTypeUInt8; | attr->srcT = TypeId::kNumberTypeUInt8; | ||||
| } | } | ||||
| const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; | |||||
| const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -62,10 +59,10 @@ STATUS TfliteCastParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | op->primitive->value.type = schema::PrimitiveType_Cast; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteCastParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteCastParser() : TfliteNodeParser("Cast") {} | TfliteCastParser() : TfliteNodeParser("Cast") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteConcatParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteConcatParser"; | MS_LOG(DEBUG) << "parse TfliteConcatParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -55,11 +52,11 @@ STATUS TfliteConcatParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteConcatParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteConcatParser() : TfliteNodeParser("Concat") {} | TfliteConcatParser() : TfliteNodeParser("Concat") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteConvParser"; | MS_LOG(DEBUG) << "parse TfliteConvParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -60,7 +57,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| // get the conv op weight tensor | // get the conv op weight tensor | ||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | |||||
| const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the weight tensor is null"; | MS_LOG(ERROR) << "the weight tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -73,7 +70,7 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| std::vector<int> params; | std::vector<int> params; | ||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | ||||
| ¶ms) != RET_OK) { | ¶ms) != RET_OK) { | ||||
| @@ -89,14 +86,14 @@ STATUS TfliteConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| op->primitive->value.type = schema::PrimitiveType_Conv2D; | op->primitive->value.type = schema::PrimitiveType_Conv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteConvParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteConvParser() : TfliteNodeParser("Conv2D") {} | TfliteConvParser() : TfliteNodeParser("Conv2D") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -23,11 +23,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteCustomParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteCustomParser"; | MS_LOG(DEBUG) << "parse TfliteCustomParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -78,12 +75,12 @@ STATUS TfliteCustomParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteCustomParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteCustomParser() : TfliteNodeParser("Custom") {} | TfliteCustomParser() : TfliteNodeParser("Custom") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteDeConvParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | MS_LOG(DEBUG) << "parse tflite Transpose_Conv parser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -61,7 +58,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| // get the conv op weight tensor | // get the conv op weight tensor | ||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | |||||
| const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the weight tensor is null"; | MS_LOG(ERROR) << "the weight tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -74,7 +71,7 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[2]; | auto data_index = tflite_op->inputs[2]; | ||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| std::vector<int> params; | std::vector<int> params; | ||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->kernelH, attr->kernelW, | ||||
| ¶ms) != RET_OK) { | ¶ms) != RET_OK) { | ||||
| @@ -90,12 +87,12 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_DeConv2D; | op->primitive->value.type = schema::PrimitiveType_DeConv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_KHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteDeConvParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} | TfliteDeConvParser() : TfliteNodeParser("DeConv2D") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,12 +22,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteDepthToSpaceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | MS_LOG(DEBUG) << "parse TfliteDepthToSpaceParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| @@ -57,10 +54,10 @@ STATUS TfliteDepthToSpaceParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | op->primitive->value.type = schema::PrimitiveType_DepthToSpace; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteDepthToSpaceParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} | TfliteDepthToSpaceParser() : TfliteNodeParser("DepthToSpace") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,12 +21,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteDepthwiseConv2DParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | MS_LOG(DEBUG) << "parse TfliteDepthwiseConv2DParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -61,7 +58,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| // get the data tensor | // get the data tensor | ||||
| auto data_index = tflite_op->inputs[1]; | auto data_index = tflite_op->inputs[1]; | ||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| if (data_tensor == nullptr) { | if (data_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the data tensor is null"; | MS_LOG(ERROR) << "the data tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -71,7 +68,7 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| // get the weight tensor | // get the weight tensor | ||||
| auto weight_index = tflite_op->inputs[1]; | auto weight_index = tflite_op->inputs[1]; | ||||
| const auto &weight_tensor = tflite_tensors[weight_index]; | |||||
| const auto &weight_tensor = tflite_model->subgraphs[0]->tensors[weight_index]; | |||||
| if (weight_tensor == nullptr) { | if (weight_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "the weight tensor is null"; | MS_LOG(ERROR) << "the weight tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -96,14 +93,14 @@ STATUS TfliteDepthwiseConv2DParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | op->primitive->value.type = schema::PrimitiveType_DepthwiseConv2D; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteDepthwiseConv2DParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | TfliteDepthwiseConv2DParser() : TfliteNodeParser("DepthwiseConv2D") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,11 +20,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteDequantizeParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | MS_LOG(DEBUG) << "parse TfliteDequantizeNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -36,12 +34,12 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; | |||||
| const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "input tensor is null"; | MS_LOG(ERROR) << "input tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; | |||||
| const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "output tensor is null"; | MS_LOG(ERROR) << "output tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -70,10 +68,10 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| op->primitive->value.type = schema::PrimitiveType_Cast; | op->primitive->value.type = schema::PrimitiveType_Cast; | ||||
| } | } | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -28,11 +28,8 @@ class TfliteDequantizeParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | TfliteDequantizeParser() : TfliteNodeParser("Dequantize") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteExpandDimsParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | MS_LOG(DEBUG) << "parse TfliteExpandDimsParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -43,17 +41,17 @@ STATUS TfliteExpandDimsParser::Parse(const std::unique_ptr<tflite::OperatorT> &t | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| std::vector<int> dims; | std::vector<int> dims; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, dims)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, dims)) { | |||||
| MS_LOG(ERROR) << "get expand_dims -> dim failed"; | MS_LOG(ERROR) << "get expand_dims -> dim failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| attr->dim = dims[0]; | attr->dim = dims[0]; | ||||
| op->primitive->value.type = schema::PrimitiveType_ExpandDims; | op->primitive->value.type = schema::PrimitiveType_ExpandDims; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | TfliteNodeRegister g_tfliteExpandDimsParser("ExpandDims", new TfliteExpandDimsParser()); | ||||
| @@ -29,11 +29,8 @@ class TfliteExpandDimsParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} | TfliteExpandDimsParser() : TfliteNodeParser("ExpandDims") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteFillParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFillParser"; | MS_LOG(DEBUG) << "parse TfliteFillParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -44,7 +41,7 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| } | } | ||||
| if (tflite_op->inputs.size() > 1) { | if (tflite_op->inputs.size() > 1) { | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->dims)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->dims)) { | |||||
| MS_LOG(ERROR) << "get fill -> dims failed"; | MS_LOG(ERROR) << "get fill -> dims failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -53,10 +50,10 @@ STATUS TfliteFillParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| op->primitive->value.type = schema::PrimitiveType_Fill; | op->primitive->value.type = schema::PrimitiveType_Fill; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteFillParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteFillParser() : TfliteNodeParser("Fill") {} | TfliteFillParser() : TfliteNodeParser("Fill") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,12 +21,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteFullyConnectedParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | MS_LOG(DEBUG) << "parse TfliteFullyConnectedParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -60,16 +57,16 @@ STATUS TfliteFullyConnectedParser::Parse(const std::unique_ptr<tflite::OperatorT | |||||
| op->primitive->value.type = schema::PrimitiveType_FullConnection; | op->primitive->value.type = schema::PrimitiveType_FullConnection; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_KHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_KHWC); | |||||
| if (hasBias) { | if (hasBias) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteFullyConnectedParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} | TfliteFullyConnectedParser() : TfliteNodeParser("FullyConnected") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteFakeQuantParser : public TfliteFullyConnectedParser { | class TfliteFakeQuantParser : public TfliteFullyConnectedParser { | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteGatherNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteGatherNdParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; | MS_LOG(DEBUG) << "parse TfliteGatherNdParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -49,11 +46,11 @@ STATUS TfliteGatherNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteGatherNdParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} | TfliteGatherNdParser() : TfliteNodeParser("GatherND") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteGatherParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteGatherParser"; | MS_LOG(DEBUG) << "parse TfliteGatherParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -55,11 +52,11 @@ STATUS TfliteGatherParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteGatherParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteGatherParser() : TfliteNodeParser("Gather") {} | TfliteGatherParser() : TfliteNodeParser("Gather") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,12 +21,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteHashtableLookupParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; | MS_LOG(DEBUG) << "parse TfliteHashtableLookupParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -47,12 +44,12 @@ STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| op->primitive->value.type = schema::PrimitiveType_HashtableLookup; | op->primitive->value.type = schema::PrimitiveType_HashtableLookup; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->outputs.size(); ++i) { | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -60,4 +57,3 @@ STATUS TfliteHashtableLookupParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); | TfliteNodeRegister g_tfliteHashtableLookupParser("HashtableLookup", new TfliteHashtableLookupParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,11 +29,8 @@ class TfliteHashtableLookupParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} | TfliteHashtableLookupParser() : TfliteNodeParser("HashtableLookup") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteL2NormParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteL2NormParser"; | MS_LOG(DEBUG) << "parse TfliteL2NormParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -51,11 +48,11 @@ STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_L2Norm; | op->primitive->value.type = schema::PrimitiveType_L2Norm; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| // set input | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| // set input and output | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteL2NormParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} | TfliteL2NormParser() : TfliteNodeParser("L2_NORMALIZATION") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteLogicalParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteLogicalParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -70,11 +67,11 @@ STATUS TfliteLogicalParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli | |||||
| } | } | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteLogicalParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteLogicalParser() : TfliteNodeParser("node_name") {} | TfliteLogicalParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteLogicalAndParser : public TfliteLogicalParser { | class TfliteLogicalAndParser : public TfliteLogicalParser { | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteLRNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteLRNParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLRNParser"; | MS_LOG(DEBUG) << "parse TfliteLRNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -56,10 +53,10 @@ STATUS TfliteLRNParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o | |||||
| op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | op->primitive->value.type = schema::PrimitiveType_LocalResponseNormalization; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteLRNParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} | TfliteLRNParser() : TfliteNodeParser("LocalResponseNorm") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,13 +21,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteLshProjectionParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteLshProjectionParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; | MS_LOG(DEBUG) << "parse TfliteLshProjectionParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -60,15 +56,14 @@ STATUS TfliteLshProjectionParser::Parse(const std::unique_ptr<tflite::OperatorT> | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | for (size_t i = 0; i < tflite_op->inputs.size(); ++i) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->inputs[i], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, | |||||
| tflite_op->outputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); | TfliteNodeRegister g_tfliteLshProjectionParser("LshProjection", new TfliteLshProjectionParser()); | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,16 +29,10 @@ class TfliteLshProjectionParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {} | TfliteLshProjectionParser() : TfliteNodeParser("LshProjection") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H | #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_LSH_PROJECTION_PARSER_H | ||||
| @@ -120,8 +120,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (status == RET_OK) { | if (status == RET_OK) { | ||||
| status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, | |||||
| &tensorsFormat, &tensorsIdMap); | |||||
| status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; | ||||
| continue; | continue; | ||||
| @@ -138,15 +137,15 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit | |||||
| STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, | ||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | ||||
| schema::MetaGraphT *sub_graph) { | schema::MetaGraphT *sub_graph) { | ||||
| for (size_t i = 0; i < tensorsId.size(); i++) { | |||||
| auto idx = tensorsId[i]; | |||||
| for (size_t i = 0; i < tensorsInfo.tensorsId.size(); i++) { | |||||
| auto idx = tensorsInfo.tensorsId[i]; | |||||
| if (idx < 0) { | if (idx < 0) { | ||||
| idx += tflite_subgraph->tensors.size(); | idx += tflite_subgraph->tensors.size(); | ||||
| } | } | ||||
| const auto &tflite_tensor = tflite_subgraph->tensors[idx]; | const auto &tflite_tensor = tflite_subgraph->tensors[idx]; | ||||
| std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | std::unique_ptr<schema::TensorT> tensor = std::make_unique<schema::TensorT>(); | ||||
| tensor->format = tensorsFormat[i]; | |||||
| tensor->format = tensorsInfo.tensorsFormat[i]; | |||||
| tensor->dataType = GetTfliteDataType(tflite_tensor->type); | tensor->dataType = GetTfliteDataType(tflite_tensor->type); | ||||
| tensor->dims = tflite_tensor->shape; | tensor->dims = tflite_tensor->shape; | ||||
| @@ -207,8 +206,8 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> | |||||
| } else { | } else { | ||||
| id = idx; | id = idx; | ||||
| } | } | ||||
| auto iter = tensorsIdMap.find(id); | |||||
| if (iter != tensorsIdMap.end()) { | |||||
| auto iter = tensorsInfo.tensorsIdMap.find(id); | |||||
| if (iter != tensorsInfo.tensorsIdMap.end()) { | |||||
| graph_inputs.push_back(iter->second); | graph_inputs.push_back(iter->second); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "get graph input failed"; | MS_LOG(ERROR) << "get graph input failed"; | ||||
| @@ -226,8 +225,8 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> | |||||
| } else { | } else { | ||||
| id = idx; | id = idx; | ||||
| } | } | ||||
| auto iter = tensorsIdMap.find(id); | |||||
| if (iter != tensorsIdMap.end()) { | |||||
| auto iter = tensorsInfo.tensorsIdMap.find(id); | |||||
| if (iter != tensorsInfo.tensorsIdMap.end()) { | |||||
| graph_outputs.push_back(iter->second); | graph_outputs.push_back(iter->second); | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "get graph output failed"; | MS_LOG(ERROR) << "get graph output failed"; | ||||
| @@ -65,9 +65,7 @@ class TfliteModelParser : public ModelParser { | |||||
| STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); | STATUS ConvertGroupDepthwiseOp(schema::MetaGraphT *sub_graph); | ||||
| private: | private: | ||||
| std::vector<int32_t> tensorsId; | |||||
| std::vector<schema::Format> tensorsFormat; | |||||
| std::map<int, int> tensorsIdMap; | |||||
| TfliteTensorsInfo tensorsInfo; | |||||
| std::vector<schema::TensorT *> tensors; | std::vector<schema::TensorT *> tensors; | ||||
| std::map<std::string, schema::CNodeT *> opMap; | std::map<std::string, schema::CNodeT *> opMap; | ||||
| @@ -38,40 +38,37 @@ class TfliteNodeParser { | |||||
| virtual ~TfliteNodeParser() = default; | virtual ~TfliteNodeParser() = default; | ||||
| virtual STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) = 0; | |||||
| virtual STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) = 0; | |||||
| void AddOpInput(schema::CNodeT *op, std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map, int idx, int new_idx, int total, schema::Format format) { | |||||
| auto iter = tensors_id_map->find(idx); | |||||
| if (iter != tensors_id_map->end()) { | |||||
| void AddOpInput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) { | |||||
| int new_idx = tensors_info->tensorsId.size(); | |||||
| auto iter = tensors_info->tensorsIdMap.find(idx); | |||||
| if (iter != tensors_info->tensorsIdMap.end()) { | |||||
| op->inputIndex.emplace_back(iter->second); | op->inputIndex.emplace_back(iter->second); | ||||
| } else { | } else { | ||||
| if (idx < 0) { | if (idx < 0) { | ||||
| idx += total; | idx += total; | ||||
| } | } | ||||
| tensors_id->emplace_back(idx); | |||||
| tensors_format->emplace_back(format); | |||||
| tensors_id_map->insert(std::make_pair(idx, new_idx)); | |||||
| tensors_info->tensorsId.emplace_back(idx); | |||||
| tensors_info->tensorsFormat.emplace_back(format); | |||||
| tensors_info->tensorsIdMap.insert(std::make_pair(idx, new_idx)); | |||||
| op->inputIndex.emplace_back(new_idx); | op->inputIndex.emplace_back(new_idx); | ||||
| } | } | ||||
| } | } | ||||
| void AddOpOutput(schema::CNodeT *op, std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map, int idx, int new_idx, int total, schema::Format format) { | |||||
| auto iter = tensors_id_map->find(idx); | |||||
| if (iter != tensors_id_map->end()) { | |||||
| void AddOpOutput(schema::CNodeT *op, TfliteTensorsInfo *tensors_info, int idx, int total, schema::Format format) { | |||||
| int new_idx = tensors_info->tensorsId.size(); | |||||
| auto iter = tensors_info->tensorsIdMap.find(idx); | |||||
| if (iter != tensors_info->tensorsIdMap.end()) { | |||||
| op->outputIndex.emplace_back(iter->second); | op->outputIndex.emplace_back(iter->second); | ||||
| } else { | } else { | ||||
| if (idx < 0) { | if (idx < 0) { | ||||
| idx += total; | idx += total; | ||||
| } | } | ||||
| tensors_id->emplace_back(idx); | |||||
| tensors_format->emplace_back(format); | |||||
| tensors_id_map->insert(std::make_pair(idx, new_idx)); | |||||
| tensors_info->tensorsId.emplace_back(idx); | |||||
| tensors_info->tensorsFormat.emplace_back(format); | |||||
| tensors_info->tensorsIdMap.insert(std::make_pair(idx, new_idx)); | |||||
| op->outputIndex.emplace_back(new_idx); | op->outputIndex.emplace_back(new_idx); | ||||
| } | } | ||||
| } | } | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteOneHotParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteOneHotParser"; | MS_LOG(DEBUG) << "parse TfliteOneHotParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -49,7 +46,7 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto axis = tflite_attr->axis; | auto axis = tflite_attr->axis; | ||||
| const auto &tensor = tflite_tensors[tflite_op->inputs[0]]; | |||||
| const auto &tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| if (tensor == nullptr) { | if (tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "tensor is null"; | MS_LOG(ERROR) << "tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -60,11 +57,11 @@ STATUS TfliteOneHotParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteOneHotParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteOneHotParser() : TfliteNodeParser("OneHot") {} | TfliteOneHotParser() : TfliteNodeParser("OneHot") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TflitePadParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TflitePadParser"; | MS_LOG(DEBUG) << "parse TflitePadParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -54,7 +51,8 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o | |||||
| } | } | ||||
| attr->paddingMode = schema::PaddingMode_CONSTANT; | attr->paddingMode = schema::PaddingMode_CONSTANT; | ||||
| attr->constantValue = 0.0f; | attr->constantValue = 0.0f; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->paddings)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->paddings)) { | |||||
| MS_LOG(ERROR) << "get pad -> paddings failed"; | MS_LOG(ERROR) << "get pad -> paddings failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -74,7 +72,7 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o | |||||
| default: | default: | ||||
| MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | MS_LOG(ERROR) << "paddingmode:" << tflite_attr->mode << " don't support"; | ||||
| return RET_INVALID_OP_ATTR; | return RET_INVALID_OP_ATTR; | ||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; | MS_LOG(ERROR) << "this pad:" << node_name << " hasn't been supported"; | ||||
| return RET_NOT_SUPPORT; | return RET_NOT_SUPPORT; | ||||
| @@ -83,14 +81,14 @@ STATUS TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_o | |||||
| op->primitive->value.type = schema::PrimitiveType_Pad; | op->primitive->value.type = schema::PrimitiveType_Pad; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| if (std::strcmp(node_name, "MirrorPad") == 0) { | if (std::strcmp(node_name, "MirrorPad") == 0) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TflitePadParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TflitePadParser() : TfliteNodeParser("Pad") {} | TflitePadParser() : TfliteNodeParser("Pad") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TflitePoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TflitePoolingParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -72,7 +69,7 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli | |||||
| // calculate pad params | // calculate pad params | ||||
| auto data_index = tflite_op->inputs[0]; | auto data_index = tflite_op->inputs[0]; | ||||
| const auto &data_tensor = tflite_tensors[data_index]; | |||||
| const auto &data_tensor = tflite_model->subgraphs[0]->tensors[data_index]; | |||||
| std::vector<int> params; | std::vector<int> params; | ||||
| if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, | if (getPaddingParam(data_tensor, attr->padMode, attr->strideH, attr->strideW, attr->windowH, attr->windowW, | ||||
| ¶ms) != RET_OK) { | ¶ms) != RET_OK) { | ||||
| @@ -88,10 +85,10 @@ STATUS TflitePoolingParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli | |||||
| op->primitive->value.type = schema::PrimitiveType_Pooling; | op->primitive->value.type = schema::PrimitiveType_Pooling; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TflitePoolingParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TflitePoolingParser() : TfliteNodeParser("node_name") {} | TflitePoolingParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteMeanPoolingParser : public TflitePoolingParser { | class TfliteMeanPoolingParser : public TflitePoolingParser { | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TflitePReLUParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TflitePReLUParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TflitePReLUParser"; | MS_LOG(DEBUG) << "parse TflitePReLUParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -47,12 +44,12 @@ STATUS TflitePReLUParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite | |||||
| op->primitive->value.type = schema::PrimitiveType_PReLU; | op->primitive->value.type = schema::PrimitiveType_PReLU; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TflitePReLUParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TflitePReLUParser() : TfliteNodeParser("PRELU") {} | TflitePReLUParser() : TfliteNodeParser("PRELU") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,11 +20,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteQuantizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteQuantizeNParser"; | MS_LOG(DEBUG) << "parse TfliteQuantizeNParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -36,12 +33,12 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; | |||||
| const auto &in_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->inputs[0]]; | |||||
| if (in_tensor == nullptr) { | if (in_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "input tensor is null"; | MS_LOG(ERROR) << "input tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; | |||||
| const auto &out_tensor = tflite_model->subgraphs[0]->tensors[tflite_op->outputs[0]]; | |||||
| if (out_tensor == nullptr) { | if (out_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "output tensor is null"; | MS_LOG(ERROR) << "output tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -70,10 +67,10 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| } | } | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -28,11 +28,8 @@ class TfliteQuantizeParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteQuantizeParser() : TfliteNodeParser("Quantize") {} | TfliteQuantizeParser() : TfliteNodeParser("Quantize") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteRangeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteRangeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRangeParser"; | MS_LOG(DEBUG) << "parse TfliteRangeParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -51,10 +48,10 @@ STATUS TfliteRangeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite | |||||
| op->primitive->value.type = schema::PrimitiveType_Range; | op->primitive->value.type = schema::PrimitiveType_Range; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteRangeParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteRangeParser() : TfliteNodeParser("Range") {} | TfliteRangeParser() : TfliteNodeParser("Range") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteRankParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteRankParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteRankParser"; | MS_LOG(DEBUG) << "parse TfliteRankParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -46,10 +43,10 @@ STATUS TfliteRankParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_ | |||||
| op->primitive->value.type = schema::PrimitiveType_Rank; | op->primitive->value.type = schema::PrimitiveType_Rank; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteRankParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteRankParser() : TfliteNodeParser("Rank") {} | TfliteRankParser() : TfliteNodeParser("Rank") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReduceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteReduceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -75,7 +72,7 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| return RET_NOT_FIND_OP; | return RET_NOT_FIND_OP; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axes)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axes)) { | |||||
| MS_LOG(ERROR) << "get reduce -> axes failed"; | MS_LOG(ERROR) << "get reduce -> axes failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -83,10 +80,10 @@ STATUS TfliteReduceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_Reduce; | op->primitive->value.type = schema::PrimitiveType_Reduce; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteReduceParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteReduceParser() : TfliteNodeParser("node_name") {} | TfliteReduceParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteReduceMaxParser : public TfliteReduceParser { | class TfliteReduceMaxParser : public TfliteReduceParser { | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteReshapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReshapeParser"; | MS_LOG(DEBUG) << "parse TfliteReshapeParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -50,18 +47,19 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto shape_tensor_index = tflite_op->inputs[1]; | auto shape_tensor_index = tflite_op->inputs[1]; | ||||
| const auto &shape_tensor = tflite_tensors[shape_tensor_index]; | |||||
| const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[shape_tensor_index]; | |||||
| if (shape_tensor == nullptr) { | if (shape_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "shape_tensor is null"; | MS_LOG(ERROR) << "shape_tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto &buf_data = tflite_model_buffer[shape_tensor->buffer]; | |||||
| auto &buf_data = tflite_model->buffers[shape_tensor->buffer]; | |||||
| if (buf_data == nullptr) { | if (buf_data == nullptr) { | ||||
| MS_LOG(ERROR) << "buf_data is null"; | MS_LOG(ERROR) << "buf_data is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (!buf_data->data.empty()) { | if (!buf_data->data.empty()) { | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, | |||||
| attr->shape)) { | |||||
| MS_LOG(ERROR) << "get reshape -> shape failed"; | MS_LOG(ERROR) << "get reshape -> shape failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -78,11 +76,11 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli | |||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | for (size_t i = 0; i < tflite_op->inputs.size(); i++) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[i], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteReshapeParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteReshapeParser() : TfliteNodeParser("Reshape") {} | TfliteReshapeParser() : TfliteNodeParser("Reshape") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteResizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteResizeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -73,13 +70,13 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| attr->preserveAspectRatio = false; | attr->preserveAspectRatio = false; | ||||
| auto tfliteResizeTensorIndex = tflite_op->inputs[1]; | auto tfliteResizeTensorIndex = tflite_op->inputs[1]; | ||||
| const auto &shape_tensor = tflite_tensors[tfliteResizeTensorIndex]; | |||||
| const auto &shape_tensor = tflite_model->subgraphs[0]->tensors[tfliteResizeTensorIndex]; | |||||
| if (shape_tensor == nullptr) { | if (shape_tensor == nullptr) { | ||||
| MS_LOG(ERROR) << "shape_tensor is null"; | MS_LOG(ERROR) << "shape_tensor is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| auto resizeTensorBufferIndex = shape_tensor->buffer; | auto resizeTensorBufferIndex = shape_tensor->buffer; | ||||
| const auto &buff = tflite_model_buffer.at(resizeTensorBufferIndex); | |||||
| const auto &buff = tflite_model->buffers.at(resizeTensorBufferIndex); | |||||
| if (buff == nullptr) { | if (buff == nullptr) { | ||||
| MS_LOG(ERROR) << "buff_data is null"; | MS_LOG(ERROR) << "buff_data is null"; | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -95,14 +92,14 @@ STATUS TfliteResizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit | |||||
| op->primitive->value.type = schema::PrimitiveType_Resize; | op->primitive->value.type = schema::PrimitiveType_Resize; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| if (buffData == nullptr) { | if (buffData == nullptr) { | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| } | } | ||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteResizeParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteResizeParser() : TfliteNodeParser("node_name") {} | TfliteResizeParser() : TfliteNodeParser("node_name") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| class TfliteResizeBilinearParser : public TfliteResizeParser { | class TfliteResizeBilinearParser : public TfliteResizeParser { | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReverseParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteReverseParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReverseParser"; | MS_LOG(DEBUG) << "parse TfliteReverseParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -43,7 +40,7 @@ STATUS TfliteReverseParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli | |||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->axis)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->axis)) { | |||||
| MS_LOG(ERROR) << "get reverse -> axis failed"; | MS_LOG(ERROR) << "get reverse -> axis failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -51,10 +48,10 @@ STATUS TfliteReverseParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli | |||||
| op->primitive->value.type = schema::PrimitiveType_Reverse; | op->primitive->value.type = schema::PrimitiveType_Reverse; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteReverseParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteReverseParser() : TfliteNodeParser("reverse") {} | TfliteReverseParser() : TfliteNodeParser("reverse") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,12 +22,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteReverseSequenceParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; | MS_LOG(DEBUG) << "parse TfliteReverseSequenceParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -56,12 +53,12 @@ STATUS TfliteReverseSequenceParser::Parse(const std::unique_ptr<tflite::Operator | |||||
| op->primitive->value.type = schema::PrimitiveType_ReverseSequence; | op->primitive->value.type = schema::PrimitiveType_ReverseSequence; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteReverseSequenceParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} | TfliteReverseSequenceParser() : TfliteNodeParser("ReverseSequence") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,11 +22,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteScatterNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteScatterNdParser::Parse(TfliteTensorsInfo *tensors_info, | |||||
| const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; | MS_LOG(DEBUG) << "parse TfliteScatterNdParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -54,14 +52,14 @@ STATUS TfliteScatterNdParser::Parse(const std::unique_ptr<tflite::OperatorT> &tf | |||||
| // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 | // in tflite, kIndices = 0, kUpdates = 1, kShape = 2 | ||||
| // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; | // in mslite, kScatterShapeIndex = 0, kScatterIndicesIndex = 1, kScatterUpdateIndex = 2; | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[2], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[1], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteScatterNdParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} | TfliteScatterNdParser() : TfliteNodeParser("ScatterNd") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteShapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteShapeParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteShapeParser"; | MS_LOG(DEBUG) << "parse TfliteShapeParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -46,10 +43,10 @@ STATUS TfliteShapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite | |||||
| op->primitive->value.type = schema::PrimitiveType_Shape; | op->primitive->value.type = schema::PrimitiveType_Shape; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteShapeParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteShapeParser() : TfliteNodeParser("Shape") {} | TfliteShapeParser() : TfliteNodeParser("Shape") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSkipGramParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteSkipGramParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; | MS_LOG(DEBUG) << "parse TfliteSkipGramParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -55,10 +52,10 @@ STATUS TfliteSkipGramParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfl | |||||
| op->primitive->value.type = schema::PrimitiveType_SkipGram; | op->primitive->value.type = schema::PrimitiveType_SkipGram; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteSkipGramParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} | TfliteSkipGramParser() : TfliteNodeParser("SkipGram") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,11 +21,8 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| STATUS TfliteSliceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, | |||||
| schema::CNodeT *op, std::vector<int32_t> *tensors_id, | |||||
| std::vector<schema::Format> *tensors_format, std::map<int, int> *tensors_id_map) { | |||||
| STATUS TfliteSliceParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) { | |||||
| MS_LOG(DEBUG) << "parse TfliteSliceParser"; | MS_LOG(DEBUG) << "parse TfliteSliceParser"; | ||||
| if (op == nullptr) { | if (op == nullptr) { | ||||
| MS_LOG(ERROR) << "op is null"; | MS_LOG(ERROR) << "op is null"; | ||||
| @@ -45,11 +42,11 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite | |||||
| attr->format = schema::Format::Format_NHWC; | attr->format = schema::Format::Format_NHWC; | ||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->begin)) { | |||||
| if (GetTfliteData(tflite_op->inputs[1], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->begin)) { | |||||
| MS_LOG(ERROR) << "get slice -> begin failed"; | MS_LOG(ERROR) << "get slice -> begin failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_tensors, tflite_model_buffer, attr->size)) { | |||||
| if (GetTfliteData(tflite_op->inputs[2], tflite_model->subgraphs[0]->tensors, tflite_model->buffers, attr->size)) { | |||||
| MS_LOG(ERROR) << "get slice -> size failed"; | MS_LOG(ERROR) << "get slice -> size failed"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| @@ -62,10 +59,10 @@ STATUS TfliteSliceParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflite | |||||
| op->primitive->value.type = schema::PrimitiveType_Slice; | op->primitive->value.type = schema::PrimitiveType_Slice; | ||||
| op->primitive->value.value = attr.release(); | op->primitive->value.value = attr.release(); | ||||
| AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->outputs[0], tensors_id->size(), | |||||
| tflite_tensors.size(), schema::Format::Format_NHWC); | |||||
| AddOpInput(op, tensors_info, tflite_op->inputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| AddOpOutput(op, tensors_info, tflite_op->outputs[0], tflite_model->subgraphs[0]->tensors.size(), | |||||
| schema::Format::Format_NHWC); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -29,11 +29,8 @@ class TfliteSliceParser : public TfliteNodeParser { | |||||
| public: | public: | ||||
| TfliteSliceParser() : TfliteNodeParser("Slice") {} | TfliteSliceParser() : TfliteNodeParser("Slice") {} | ||||
| STATUS Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::vector<std::unique_ptr<tflite::TensorT>> &tflite_tensors, | |||||
| const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer, schema::CNodeT *op, | |||||
| std::vector<int32_t> *tensors_id, std::vector<schema::Format> *tensors_format, | |||||
| std::map<int, int> *tensors_id_map) override; | |||||
| STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op, | |||||
| const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||