| @@ -56,7 +56,7 @@ struct ArgmxxOp { | |||||
| ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C): | ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C): | ||||
| src(src), dst(dst), A(A), B(B), C(C), | src(src), dst(dst), A(A), B(B), C(C), | ||||
| INIT(wtype(is_max ? DTypeTrait<stype_>::min() : | INIT(wtype(is_max ? DTypeTrait<stype_>::min() : | ||||
| DTypeTrait<stype_>::max(), -1)) | |||||
| DTypeTrait<stype_>::max(), 0)) | |||||
| { | { | ||||
| } | } | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) | MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) | ||||
| @@ -45,7 +45,7 @@ void exec_forward(_megdnn_tensor_in src, | |||||
| reduce::get_ABC(src.layout, A, B, C, param.axis); | reduce::get_ABC(src.layout, A, B, C, param.axis); | ||||
| for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) { | for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) { | ||||
| float best_val = traits<is_max>::init; | float best_val = traits<is_max>::init; | ||||
| size_t best_arg = -1; | |||||
| size_t best_arg = 0; | |||||
| for (size_t b = 0; b < B; ++b) { | for (size_t b = 0; b < B; ++b) { | ||||
| float curr_val = float(src.ptr<T>()[(a*B+b)*C+c]); | float curr_val = float(src.ptr<T>()[(a*B+b)*C+c]); | ||||
| if (traits<is_max>::better_than(curr_val, best_val)) { | if (traits<is_max>::better_than(curr_val, best_val)) { | ||||
| @@ -527,3 +527,20 @@ def test_nms_is_same(): | |||||
| assert op3 != op4 | assert op3 != op4 | ||||
| def test_argmxx_on_inf(): | |||||
| def run_argmax(): | |||||
| x = F.zeros((100, 100)) | |||||
| x[:] = -float("inf") | |||||
| idxs = F.argmax(x, axis=0) | |||||
| return idxs | |||||
| def run_argmin(): | |||||
| x = F.zeros((100, 100)) | |||||
| x[:] = float("inf") | |||||
| idxs = F.argmin(x, axis=0) | |||||
| return idxs | |||||
| assert all(run_argmax() >= 0) | |||||
| assert all(run_argmin() >= 0) | |||||