From 5969769b4cc7c126edf2d9ea32c85befab7314fb Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Tue, 4 Aug 2020 19:13:24 +0800 Subject: [PATCH] topk_int8 --- mindspore/lite/src/ops/topk.cc | 12 +-- mindspore/lite/src/populate_parameter.cc | 2 +- .../lite/src/runtime/kernel/arm/fp32/topk.cc | 24 +++--- .../lite/src/runtime/kernel/arm/fp32/topk.h | 11 ++- .../src/runtime/kernel/arm/int8/topk_int8.cc | 76 +++++++++++++++++++ .../src/runtime/kernel/arm/int8/topk_int8.h | 42 ++++++++++ .../kernel/arm/opclib/{ => fp32}/topk.cc | 12 +-- .../kernel/arm/opclib/{ => fp32}/topk.h | 8 +- .../kernel/arm/opclib/int8/topk_int8.cc | 54 +++++++++++++ .../kernel/arm/opclib/int8/topk_int8.h | 30 ++++++++ .../kernel/arm/fp32/topk_fp32_tests.cc | 65 ++++++++++++++++ .../kernel/arm/int8/topk_int8_tests.cc | 65 ++++++++++++++++ 12 files changed, 367 insertions(+), 34 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h rename mindspore/lite/src/runtime/kernel/arm/opclib/{ => fp32}/topk.cc (78%) rename mindspore/lite/src/runtime/kernel/arm/opclib/{ => fp32}/topk.h (86%) create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.cc create mode 100644 mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.h create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index 39a7456754..e3dbee034e 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -35,13 +35,15 @@ int TopK::InferShape(std::vector inputs_, std::vectorprimitive->value_as_TopK(); MS_ASSERT(topk_prim != nullptr); - output0->set_shape(input->shape()); + auto out_shape = input->shape(); + out_shape[out_shape.size() - 1] = topk_prim->k(); + + output0->set_shape(out_shape); output0->set_data_type(input->data_type()); - // output0->shape().back() = topk_prim->k(); + output0->SetFormat(input->GetFormat()); - output1->set_shape(input->shape()); - output1->set_data_type(input->data_type()); - // output1->shape().back() = topk_prim->k(); + output1->set_shape(out_shape); + output1->set_data_type(kNumberTypeInt32); output1->SetFormat(input->GetFormat()); return RET_OK; diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 373e253a9a..f80538f7a5 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -34,7 +34,7 @@ #include "src/runtime/kernel/arm/opclib/matmul.h" #include "src/runtime/kernel/arm/opclib/fp32/softmax.h" #include "src/runtime/kernel/arm/opclib/tile.h" -#include "src/runtime/kernel/arm/opclib/topk.h" +#include "src/runtime/kernel/arm/opclib/fp32/topk.h" #include "src/runtime/kernel/arm/opclib/fp32/reduce.h" #include "src/runtime/kernel/arm/opclib/fp32/activation.h" #include "src/runtime/kernel/arm/opclib/fp32/arithmetic.h" diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc index 3f639ca23d..954ec041bc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/topk.cc @@ -25,11 +25,18 @@ using mindspore::schema::PrimitiveType_TopK; namespace mindspore::kernel { int TopKCPUKernel::Init() { + TopkParameter *parameter = reinterpret_cast(opParameter); lite::tensor::Tensor *input = inputs_.at(0); - topk_parameter_->last_dim_size_ = input->shape()[input->shape().size() - 1]; - topk_parameter_->loop_num_ = 1; + parameter->last_dim_size_ = input->shape()[input->shape().size() - 1]; + parameter->loop_num_ = 1; for (int i = 0; i < input->shape().size() - 1; ++i) { - topk_parameter_->loop_num_ *= input->shape()[i]; + parameter->loop_num_ *= input->shape()[i]; + } + + parameter->topk_node_list_ = malloc(sizeof(TopkNode) * parameter->last_dim_size_); + if (parameter->topk_node_list_ == nullptr) { + MS_LOG(ERROR) << "malloc fail."; + return RET_ERROR; } return RET_OK; } @@ -39,14 +46,9 @@ int TopKCPUKernel::ReSize() { return RET_OK; } int TopKCPUKernel::Run() { auto input_data = reinterpret_cast(inputs_.at(0)->Data()); auto output_data = reinterpret_cast(outputs_.at(0)->Data()); - auto output_index = reinterpret_cast(outputs_.at(1)->Data()); + auto output_index = reinterpret_cast(outputs_.at(1)->Data()); - Node *top_map = reinterpret_cast(malloc(sizeof(Node) * topk_parameter_->last_dim_size_)); - MS_EXCEPTION_IF_NULL(top_map); - topk_parameter_->topk_node_list_ = top_map; - Topk(input_data, output_data, output_index, topk_parameter_); - free(top_map); - topk_parameter_->topk_node_list_ = nullptr; + Topk(input_data, output_data, output_index, reinterpret_cast(opParameter)); return RET_OK; } @@ -54,7 +56,6 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector &outputs, OpParameter *parameter, const lite::Context *ctx, const KernelKey &desc) { MS_ASSERT(parameter != nullptr); - MS_ASSERT(desc.type == PrimitiveType_Tile); auto *kernel = new (std::nothrow) TopKCPUKernel(parameter, inputs, outputs); if (kernel == nullptr) { MS_LOG(ERROR) << "new TopKCPUKernel fail!"; @@ -73,4 +74,3 @@ kernel::LiteKernel *CpuTopKFp32KernelCreator(const std::vector #include "src/lite_kernel.h" -#include "src/runtime/kernel/arm/opclib/topk.h" +#include "src/runtime/kernel/arm/opclib/fp32/topk.h" namespace mindspore::kernel { class TopKCPUKernel : public LiteKernel { public: explicit TopKCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) { - topk_parameter_ = reinterpret_cast(parameter); + : LiteKernel(parameter, inputs, outputs) {} + ~TopKCPUKernel() override { + TopkParameter *parameter = reinterpret_cast(opParameter); + free(parameter->topk_node_list_); } - ~TopKCPUKernel() override {} int Init() override; int ReSize() override; int Run() override; private: - TopkParameter *topk_parameter_; }; } // namespace mindspore::kernel #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TOPK_H_ - diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc new file mode 100644 index 0000000000..2280d9e078 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/topk_int8.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_TopK; + +namespace mindspore::kernel { +int TopKInt8CPUKernel::Init() { + TopkParameter *parameter = reinterpret_cast(opParameter); + lite::tensor::Tensor *input = inputs_.at(0); + parameter->last_dim_size_ = input->shape()[input->shape().size() - 1]; + parameter->loop_num_ = 1; + for (int i = 0; i < input->shape().size() - 1; ++i) { + parameter->loop_num_ *= input->shape()[i]; + } + + parameter->topk_node_list_ = malloc(sizeof(TopkNodeInt8) * parameter->last_dim_size_); + if (parameter->topk_node_list_ == nullptr) { + MS_LOG(ERROR) << "malloc fail."; + return RET_ERROR; + } + return RET_OK; +} + +int TopKInt8CPUKernel::ReSize() { return RET_OK; } + +int TopKInt8CPUKernel::Run() { + int8_t *input_data = reinterpret_cast(inputs_.at(0)->Data()); + int8_t *output_data = reinterpret_cast(outputs_.at(0)->Data()); + int32_t *output_index = reinterpret_cast(outputs_.at(1)->Data()); + + TopkInt8(input_data, output_data, output_index, reinterpret_cast(opParameter)); + return RET_OK; +} + +kernel::LiteKernel *CpuTopKInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::Context *ctx, const KernelKey &desc) { + MS_ASSERT(parameter != nullptr); + auto *kernel = new (std::nothrow) TopKInt8CPUKernel(parameter, inputs, outputs); + if (kernel == nullptr) { + MS_LOG(ERROR) << "new TopKInt8CPUKernel fail!"; + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_TopK, CpuTopKInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h new file mode 100644 index 0000000000..ab762c3c9d --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/topk_int8.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_ + +#include +#include "src/lite_kernel.h" +#include "src/runtime/kernel/arm/opclib/int8/topk_int8.h" + +namespace mindspore::kernel { +class TopKInt8CPUKernel : public LiteKernel { + public: + explicit TopKInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : LiteKernel(parameter, inputs, outputs) {} + ~TopKInt8CPUKernel() override { + TopkParameter *parameter = reinterpret_cast(opParameter); + free(parameter->topk_node_list_); + } + + int Init() override; + int ReSize() override; + int Run() override; + + private: +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TOPK_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.cc similarity index 78% rename from mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc rename to mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.cc index 30da41e3ff..72e77f4fad 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/topk.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.cc @@ -14,25 +14,25 @@ * limitations under the License. */ -#include "src/runtime/kernel/arm/opclib/topk.h" +#include "src/runtime/kernel/arm/opclib/fp32/topk.h" int DescendCmp(const void *a, const void *b) { - return ((const Node *)b)->element - ((const Node *)a)->element; + return ((const TopkNode *)b)->element - ((const TopkNode *)a)->element; } int AscendCmp(const void *a, const void *b) { - return ((const Node *)a)->element - ((const Node *)b)->element; + return ((const TopkNode *)a)->element - ((const TopkNode *)b)->element; } -void Topk(float *input_data, float *output_data, float *output_index, TopkParameter *parameter) { +void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter) { int last_dim_size = parameter->last_dim_size_; int loop_num = parameter->loop_num_; int k = parameter->k_; - Node *top_map = parameter->topk_node_list_; + TopkNode *top_map = (TopkNode *)parameter->topk_node_list_; float *cur_input_data = input_data; float *cur_output_data = output_data; - float *cur_output_index = output_index; + int32_t *cur_output_index = output_index; for (int i = 0; i < loop_num; i++) { for (int j = 0; j < last_dim_size; j++) { top_map[j].element = *(cur_input_data + j); diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/topk.h b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h similarity index 86% rename from mindspore/lite/src/runtime/kernel/arm/opclib/topk.h rename to mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h index 3a038aa592..7d3ca97dda 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/topk.h +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h @@ -19,9 +19,9 @@ #include "src/runtime/kernel/arm/opclib/op_base.h" -struct Node { +struct TopkNode { float element; - float index; + int32_t index; }; struct TopkParameter { @@ -30,10 +30,10 @@ struct TopkParameter { int loop_num_; int k_; bool sorted_; - Node *topk_node_list_; + void *topk_node_list_; }; -void Topk(float *input_data, float *output_data, float *output_index, TopkParameter *parameter); +void Topk(float *input_data, float *output_data, int32_t *output_index, TopkParameter *parameter); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_TOPK_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.cc new file mode 100644 index 0000000000..8394d2bf98 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/opclib/int8/topk_int8.h" + +int DescendCmpInt8(const void *a, const void *b) { + return ((const TopkNodeInt8 *)b)->element - ((const TopkNodeInt8 *)a)->element; +} + +int AscendCmpInt8(const void *a, const void *b) { + return ((const TopkNodeInt8 *)a)->element - ((const TopkNodeInt8 *)b)->element; +} + +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter) { + int last_dim_size = parameter->last_dim_size_; + int loop_num = parameter->loop_num_; + int k = parameter->k_; + TopkNodeInt8 *top_map = (TopkNodeInt8 *)parameter->topk_node_list_; + + int8_t *cur_input_data = input_data; + int8_t *cur_output_data = output_data; + int32_t *cur_output_index = output_index; + for (int i = 0; i < loop_num; i++) { + for (int j = 0; j < last_dim_size; j++) { + top_map[j].element = *(cur_input_data + j); + top_map[j].index = j; + } + if (parameter->sorted_) { + qsort(top_map, last_dim_size, sizeof(top_map[0]), DescendCmpInt8); + } else { + qsort(top_map, last_dim_size, sizeof(top_map[0]), AscendCmpInt8); + } + for (int m = 0; m < k; m++) { + cur_output_data[m] = top_map[m].element; + cur_output_index[m] = top_map[m].index; + } + cur_input_data += last_dim_size; + cur_output_data += k; + cur_output_index += k; + } +} diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.h b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.h new file mode 100644 index 0000000000..3a33697461 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/int8/topk_int8.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_ + +#include "src/runtime/kernel/arm/opclib/op_base.h" +#include "src/runtime/kernel/arm/opclib/fp32/topk.h" + +struct TopkNodeInt8 { + int8_t element; + int32_t index; +}; + +void TopkInt8(int8_t *input_data, int8_t *output_data, int32_t *output_index, TopkParameter *parameter); + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_TOPK_INT8_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc new file mode 100644 index 0000000000..aaeda8026d --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/topk_fp32_tests.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { +class TestTopKFp32 : public mindspore::Common { + public: + TestTopKFp32() {} +}; + +TEST_F(TestTopKFp32, TopK) { + lite::tensor::Tensor in_tensor(kNumberTypeFloat32, {2, 2, 3}); + lite::tensor::Tensor out_tensor0(kNumberTypeFloat32, {2, 2, 2}); + lite::tensor::Tensor out_tensor1(kNumberTypeInt32, {2, 2, 2}); + float input_data[] = {1, 2, 3, 6, 5, 4, 9, 8, 7, 10, 12, 11}; + float output_data0[8] = {0}; + int32_t output_data1[8] = {0}; + in_tensor.SetData(input_data); + out_tensor0.SetData(output_data0); + out_tensor1.SetData(output_data1); + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor0, &out_tensor1}; + + TopkParameter parameter = {{}, 3, 4, 2, true}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_TopK}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + float expect0[] = {3, 2, 6, 5, 9, 8, 12, 11}; + int32_t expect1[] = {2, 1, 0, 1, 0, 1, 1, 2}; + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(output_data0[i], expect0[i]); + EXPECT_EQ(output_data1[i], expect1[i]); + } + + in_tensor.SetData(nullptr); + out_tensor0.SetData(nullptr); + out_tensor1.SetData(nullptr); +} +} // namespace mindspore diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc new file mode 100644 index 0000000000..54ec2f738c --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/int8/topk_int8_tests.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "common/common_test.h" +#include "mindspore/lite/src/runtime/kernel/arm/opclib/fp32/topk.h" +#include "mindspore/lite/src/kernel_registry.h" + +namespace mindspore { +class TestTopKInt8 : public mindspore::Common { + public: + TestTopKInt8() {} +}; + +TEST_F(TestTopKInt8, TopK) { + lite::tensor::Tensor in_tensor(kNumberTypeInt8, {2, 2, 3}); + lite::tensor::Tensor out_tensor0(kNumberTypeInt8, {2, 2, 2}); + lite::tensor::Tensor out_tensor1(kNumberTypeInt32, {2, 2, 2}); + int8_t input_data[] = {1, 2, 3, 6, 5, 4, 9, 8, 7, 10, 12, 11}; + int8_t output_data0[8] = {0}; + int32_t output_data1[8] = {0}; + in_tensor.SetData(input_data); + out_tensor0.SetData(output_data0); + out_tensor1.SetData(output_data1); + std::vector inputs = {&in_tensor}; + std::vector outputs = {&out_tensor0, &out_tensor1}; + + TopkParameter parameter = {{}, 3, 4, 2, true}; + kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_TopK}; + + auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); + ASSERT_NE(creator, nullptr); + + auto kernel = creator(inputs, outputs, reinterpret_cast(¶meter), nullptr, desc); + ASSERT_NE(kernel, nullptr); + + auto ret = kernel->Run(); + EXPECT_EQ(0, ret); + + int8_t expect0[] = {3, 2, 6, 5, 9, 8, 12, 11}; + int32_t expect1[] = {2, 1, 0, 1, 0, 1, 1, 2}; + for (int i = 0; i < 8; ++i) { + EXPECT_EQ(output_data0[i], expect0[i]); + EXPECT_EQ(output_data1[i], expect1[i]); + } + + in_tensor.SetData(nullptr); + out_tensor0.SetData(nullptr); + out_tensor1.SetData(nullptr); +} +} // namespace mindspore