Browse Source

fix reshape and l2norm tflite parser bug

tags/v1.0.0
lyvette 5 years ago
parent
commit
34361e2fe4
3 changed files with 10 additions and 8 deletions
  1. +0
    -1
      mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc
  2. +0
    -4
      mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc
  3. +10
    -3
      mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc

+ 0
- 1
mindspore/lite/tools/converter/parser/tflite/tflite_activation_parser.cc View File

@@ -51,7 +51,6 @@ STATUS TfliteActivationParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
if (std::strcmp(node_name, "Relu") == 0) {
MS_LOG(DEBUG) << "parse TfliteReluParser";
attr->type = schema::ActivationType_RELU;

} else if (std::strcmp(node_name, "Relu6") == 0) {
MS_LOG(DEBUG) << "parse TfliteRelu6Parser";
attr->type = schema::ActivationType_RELU6;


+ 0
- 4
mindspore/lite/tools/converter/parser/tflite/tflite_l2norm_parser.cc View File

@@ -51,10 +51,6 @@ STATUS TfliteL2NormParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}
auto data_index = tflite_op->inputs[0];
if (static_cast<int>(tflite_op->inputs.size()) <= data_index) {
MS_LOG(ERROR) << "the size of input should be greater than " << data_index;
return RET_ERROR;
}
const auto &data_tensor = tflite_tensors[data_index];
if (data_tensor == nullptr) {
MS_LOG(ERROR) << "the input tensor is null";


+ 10
- 3
mindspore/lite/tools/converter/parser/tflite/tflite_reshape_parser.cc View File

@@ -57,9 +57,16 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli
MS_LOG(ERROR) << "shape_tensor is null";
return RET_NULL_PTR;
}
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) {
MS_LOG(ERROR) << "get reshape -> shape failed";
return RET_ERROR;
auto &buf_data = tflite_model_buffer[shape_tensor->buffer];
if (buf_data == nullptr) {
MS_LOG(ERROR) << "buf_data is null";
return RET_NULL_PTR;
}
if (!buf_data->data.empty()) {
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->shape)) {
MS_LOG(ERROR) << "get reshape -> shape failed";
return RET_ERROR;
}
}
} else {
attr->format = schema::Format_NHWC;


Loading…
Cancel
Save