|
|
|
@@ -40,6 +40,34 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1) { |
|
|
|
param.data_type_ = 43; |
|
|
|
param.dims_size_ = 2; |
|
|
|
param.get_max_ = true; |
|
|
|
param.keep_dims_ = false; |
|
|
|
ArgMinMax(in.data(), out, shape.data(), ¶m); |
|
|
|
for (size_t i = 0; i < except_out.size(); ++i) { |
|
|
|
std::cout << out[i] << " "; |
|
|
|
} |
|
|
|
std::cout << "\n"; |
|
|
|
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestArgMinMaxTestFp32, ArgMaxTest1_keep_dim) { |
|
|
|
std::vector<float> in = {10, 20, 30, 40, 90, |
|
|
|
20, 11, 15, 1, 50, |
|
|
|
30, 45, 25, 50, 30}; |
|
|
|
std::vector<float> except_out = {2, 2, 0, 2, 0}; |
|
|
|
std::vector<int> shape = {3, 5}; |
|
|
|
float out[5]; |
|
|
|
ArgMinMaxParameter param; |
|
|
|
param.topk_ = 1; |
|
|
|
param.out_value_ = false; |
|
|
|
param.axis_ = 0; |
|
|
|
param.data_type_ = 43; |
|
|
|
param.dims_size_ = 2; |
|
|
|
param.get_max_ = true; |
|
|
|
param.keep_dims_ = true; |
|
|
|
param.arg_elements_ = reinterpret_cast<ArgElement *>(malloc(shape[param.axis_] * sizeof(ArgElement))); |
|
|
|
std::vector<int> out_shape = {1, 5}; |
|
|
|
ComputeStrides(shape.data(), param.in_strides_, shape.size()); |
|
|
|
ComputeStrides(out_shape.data(), param.out_strides_, out_shape.size()); |
|
|
|
ArgMinMax(in.data(), out, shape.data(), ¶m); |
|
|
|
for (size_t i = 0; i < except_out.size(); ++i) { |
|
|
|
std::cout << out[i] << " "; |
|
|
|
@@ -62,6 +90,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMaxTest2) { |
|
|
|
param.data_type_ = 43; |
|
|
|
param.dims_size_ = 2; |
|
|
|
param.get_max_ = true; |
|
|
|
param.keep_dims_ = false; |
|
|
|
ArgMinMax(in.data(), out, shape.data(), ¶m); |
|
|
|
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001); |
|
|
|
} |
|
|
|
@@ -80,6 +109,7 @@ TEST_F(TestArgMinMaxTestFp32, ArgMinTest2) { |
|
|
|
param.data_type_ = 43; |
|
|
|
param.dims_size_ = 2; |
|
|
|
param.get_max_ = false; |
|
|
|
param.keep_dims_ = false; |
|
|
|
ArgMinMax(in.data(), out, shape.data(), ¶m); |
|
|
|
CompareOutputData(out, except_out.data(), except_out.size(), 0.000001); |
|
|
|
} |
|
|
|
|