Browse Source

add fp16 testcase

tags/v0.7.0-beta
cjh9368 5 years ago
parent
commit
c49bc53b5b
2 changed files with 10 additions and 12 deletions
  1. +1
    -0
      mindspore/lite/test/models_tflite.cfg
  2. +9
    -12
      mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc

+ 1
- 0
mindspore/lite/test/models_tflite.cfg View File

@@ -56,3 +56,4 @@ ml_ocr_latin.tflite
hiai_ssd_mobilenetv2_object.tflite
inception_v4.tflite
ml_object_detect.tflite
mtk_model_normalize_object_scene_ps_20200519_f16.tflite

+ 9
- 12
mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc View File

@@ -43,8 +43,7 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
}

STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const tflite::TensorT *tflite_tensor,
schema::TensorT *tensor) {
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
auto count = 1;
std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; });
auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType));
@@ -92,8 +91,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor

STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const QuantType &quant_type,
schema::MetaGraphT* sub_graph) {
const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
int idx = 0;
for (const auto &tflite_op : tflite_subgraph->operators) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
@@ -126,7 +124,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit

STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
schema::MetaGraphT* sub_graph) {
schema::MetaGraphT *sub_graph) {
for (int i = 0; i < tensorsId.size(); i++) {
auto idx = tensorsId[i];
if (idx < 0) {
@@ -164,8 +162,9 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
}

// quant param
if (!(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) {
if (tflite_tensor->quantization != nullptr &&
!(tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty())) {
SetTensorQuantParam(tflite_tensor, tensor.get());
}

@@ -180,7 +179,7 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
}

STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
schema::MetaGraphT* sub_graph) {
schema::MetaGraphT *sub_graph) {
int id;

// graph input
@@ -217,7 +216,7 @@ STATUS TfliteModelParser::GetGraphInfo(const std::unique_ptr<tflite::SubGraphT>
return RET_OK;
}

STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT* sub_graph) {
STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT *sub_graph) {
for (auto &op : sub_graph->nodes) {
if (op->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
auto attr = op->primitive->value.AsDepthwiseConv2D();
@@ -270,9 +269,7 @@ STATUS TfliteModelParser::UpdateOp(schema::MetaGraphT* sub_graph) {
return RET_OK;
}


MetaGraphT *TfliteModelParser::Parse(const std::string &model_file,
const std::string &weight_file,
MetaGraphT *TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file,
const QuantType &quant_type) {
std::unique_ptr<schema::MetaGraphT> sub_graph(new schema::MetaGraphT);
sub_graph->name = "MS_model converted by TF-Lite";


Loading…
Cancel
Save