Browse Source

!4584 fix topkv2 parser

Merge pull request !4584 from sunsuodong/fix_topkv2_parser
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
485f3dde01
4 changed files with 9 additions and 15 deletions
  1. +0
    -1
      mindspore/lite/schema/model.fbs
  2. +0
    -6
      mindspore/lite/schema/ops.fbs
  3. +4
    -5
      mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc
  4. +5
    -3
      mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc

+ 0
- 1
mindspore/lite/schema/model.fbs View File

@@ -190,7 +190,6 @@ union PrimitiveType {
ActivationGrad,
PriorBox,
SpaceToBatchND,
TopKV2,
Return,
MakeTuple,
ToFormat,


+ 0
- 6
mindspore/lite/schema/ops.fbs View File

@@ -872,12 +872,6 @@ table SpaceToBatchND {
paddings : [int];
}

table TopKV2 {
k : [int];
sorted : bool = true;
}


table MakeTuple {
}



+ 4
- 5
mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc View File

@@ -31,14 +31,13 @@ TEST_F(TestTfliteParserTopKV2, OpType) {
ASSERT_NE(meta_graph, nullptr);
ASSERT_GT(meta_graph->nodes.size(), 0);
ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr);
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopKV2) << "wrong Op Type";
ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopK) << "wrong Op Type";
}

TEST_F(TestTfliteParserTopKV2, AttrValue) {
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2();
std::vector<int> k = {3};
ASSERT_EQ(val->k, k);
ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopK(), nullptr);
auto val = meta_graph->nodes.front()->primitive->value.AsTopK();
ASSERT_EQ(val->k, 3);
ASSERT_EQ(val->sorted, true);
}
} // namespace mindspore

+ 5
- 3
mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc View File

@@ -41,15 +41,17 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_NULL_PTR;
}

std::unique_ptr<schema::TopKV2T> attr(new schema::TopKV2T());
std::unique_ptr<schema::TopKT> attr(new schema::TopKT());

attr->sorted = true;
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->k)) {
std::vector<int32_t> k;
if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, k)) {
MS_LOG(ERROR) << "get topKV2 -> k failed";
return RET_ERROR;
}
attr->k = k.front();

op->primitive->value.type = schema::PrimitiveType_TopKV2;
op->primitive->value.type = schema::PrimitiveType_TopK;
op->primitive->value.value = attr.release();

AddOpInput(op, tensors_id, tensors_format, tensors_id_map,


Loading…
Cancel
Save