From 676e44a130f80b35fd52c26f82fccd3e7e075598 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Mon, 17 Aug 2020 16:17:43 +0800 Subject: [PATCH] fix_topkv2_parser --- mindspore/lite/schema/model.fbs | 1 - mindspore/lite/schema/ops.fbs | 6 ------ .../parser/tflite/tflite_topk_v2_parser_test.cc | 9 ++++----- .../converter/parser/tflite/tflite_topk_v2_parser.cc | 8 +++++--- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 0c0ad360b6..d4b0bd1552 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -190,7 +190,6 @@ union PrimitiveType { ActivationGrad, PriorBox, SpaceToBatchND, - TopKV2, Return, MakeTuple, ToFormat, diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 2e93983c16..b6458e7873 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -872,12 +872,6 @@ table SpaceToBatchND { paddings : [int]; } -table TopKV2 { - k : [int]; - sorted : bool = true; -} - - table MakeTuple { } diff --git a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc index a2623e8390..569f4112f8 100644 --- a/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc +++ b/mindspore/lite/test/ut/tools/converter/parser/tflite/tflite_topk_v2_parser_test.cc @@ -31,14 +31,13 @@ TEST_F(TestTfliteParserTopKV2, OpType) { ASSERT_NE(meta_graph, nullptr); ASSERT_GT(meta_graph->nodes.size(), 0); ASSERT_NE(meta_graph->nodes.front()->primitive.get(), nullptr); - ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopKV2) << "wrong Op Type"; + ASSERT_EQ(meta_graph->nodes.front()->primitive->value.type, schema::PrimitiveType_TopK) << "wrong Op Type"; } TEST_F(TestTfliteParserTopKV2, AttrValue) { - ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopKV2(), nullptr); - auto val = meta_graph->nodes.front()->primitive->value.AsTopKV2(); - std::vector k = {3}; - ASSERT_EQ(val->k, k); + ASSERT_NE(meta_graph->nodes.front()->primitive->value.AsTopK(), nullptr); + auto val = meta_graph->nodes.front()->primitive->value.AsTopK(); + ASSERT_EQ(val->k, 3); ASSERT_EQ(val->sorted, true); } } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc index a89a641d13..7c8fbb0b1e 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_topk_v2_parser.cc @@ -41,15 +41,17 @@ STATUS TfliteTopKV2Parser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - std::unique_ptr attr(new schema::TopKV2T()); + std::unique_ptr attr(new schema::TopKT()); attr->sorted = true; - if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, attr->k)) { + std::vector k; + if (GetTfliteData(tflite_op->inputs[1], tflite_tensors, tflite_model_buffer, k)) { MS_LOG(ERROR) << "get topKV2 -> k failed"; return RET_ERROR; } + attr->k = k.front(); - op->primitive->value.type = schema::PrimitiveType_TopKV2; + op->primitive->value.type = schema::PrimitiveType_TopK; op->primitive->value.value = attr.release(); AddOpInput(op, tensors_id, tensors_format, tensors_id_map,