GitOrigin-RevId: 33d5a62478
tags/v1.5.0
| @@ -124,4 +124,5 @@ __ai void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { | |||||
| } // namespace | } // namespace | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #undef __ai | #undef __ai | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -30,10 +30,12 @@ bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable( | |||||
| bool avaible = param.src_type.enumv() == DTypeEnum::Float32 && | bool avaible = param.src_type.enumv() == DTypeEnum::Float32 && | ||||
| param.format == Param::Format::NCHW44 && | param.format == Param::Format::NCHW44 && | ||||
| (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && | ||||
| fh == fw && sh == sw && | |||||
| (fh == 2 || fh == 3 || fh == 4 || fh == 5) && | |||||
| (sh == 1 || sh == 2); | |||||
| return avaible; | |||||
| fh == fw && sh == sw; | |||||
| bool size_ok = ((fh == 2 || fh == 3 || fh == 4 || fh == 5) && | |||||
| (sh == 1 || sh == 2)); | |||||
| size_ok |= ((fh == 9 || fh == 13) && (sh == 1)); | |||||
| return avaible && size_ok; | |||||
| } | } | ||||
| void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | ||||
| @@ -94,6 +96,15 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | |||||
| megdnn_assert(0, "invalid stride %d", sh); \ | megdnn_assert(0, "invalid stride %d", sh); \ | ||||
| } | } | ||||
| #define DISPATCH_STRIDE_1(filter) \ | |||||
| switch (sh) { \ | |||||
| case 1: \ | |||||
| DISPATCH_MODE(filter, 1); \ | |||||
| break; \ | |||||
| default: \ | |||||
| megdnn_assert(0, "invalid stride %d", sh); \ | |||||
| } | |||||
| #define DISPATCH_FILTER() \ | #define DISPATCH_FILTER() \ | ||||
| switch (fh) { \ | switch (fh) { \ | ||||
| case 2: \ | case 2: \ | ||||
| @@ -108,6 +119,12 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | |||||
| case 5: \ | case 5: \ | ||||
| DISPATCH_STRIDE(5); \ | DISPATCH_STRIDE(5); \ | ||||
| break; \ | break; \ | ||||
| case 9: \ | |||||
| DISPATCH_STRIDE_1(9); \ | |||||
| break; \ | |||||
| case 13: \ | |||||
| DISPATCH_STRIDE_1(13); \ | |||||
| break; \ | |||||
| default: \ | default: \ | ||||
| megdnn_assert(0, "invalid filter %d", fh); \ | megdnn_assert(0, "invalid filter %d", fh); \ | ||||
| } | } | ||||
| @@ -123,4 +140,4 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -64,6 +64,8 @@ INSTANCE_CAL(2) | |||||
| INSTANCE_CAL(3) | INSTANCE_CAL(3) | ||||
| INSTANCE_CAL(4) | INSTANCE_CAL(4) | ||||
| INSTANCE_CAL(5) | INSTANCE_CAL(5) | ||||
| INSTANCE_CAL(9) | |||||
| INSTANCE_CAL(13) | |||||
| #undef INSTANCE_CAL | #undef INSTANCE_CAL | ||||
| #undef CALCULATE_AVG_CB | #undef CALCULATE_AVG_CB | ||||
| @@ -305,4 +307,4 @@ static inline void pooling_fp32_nchw44(const float32_t* src, float32_t* dst, | |||||
| } // namespace arm_common | } // namespace arm_common | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -116,6 +116,31 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W9_w13_NCHW44) | |||||
| { | |||||
| UniformIntRNG rng{-10, 10}; | |||||
| Checker<Pooling> checker(handle()); | |||||
| checker.set_rng(0, &rng); | |||||
| // clang-format off | |||||
| for (size_t ih: {20, 15}) | |||||
| for (size_t iw: {15, 20}) | |||||
| for (size_t kernel: {9, 13}) | |||||
| for (size_t pad: {4, 6}) | |||||
| for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) | |||||
| if (kernel > pad) | |||||
| { | |||||
| param::Pooling param; | |||||
| param.mode = mode; | |||||
| param.format = param::Pooling::Format::NCHW44; | |||||
| param.pad_h = pad; | |||||
| param.pad_w = pad; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.window_h = param.window_w = kernel ; | |||||
| checker.set_param(param).exec(TensorShapeArray{{2, 8, ih, iw, 4}, {}}); | |||||
| } | |||||
| // clang-format on | |||||
| } | |||||
| TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) | ||||
| { | { | ||||
| UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; | ||||
| @@ -2,7 +2,7 @@ | |||||
| set -e | set -e | ||||
| ARCHS=("arm64-v8a" "armeabi-v7a") | ARCHS=("arm64-v8a" "armeabi-v7a") | ||||
| BUILD_TYPE=Release | |||||
| BUILD_TYPE=RelWithDebInfo | |||||
| MGE_ARMV8_2_FEATURE_FP16=OFF | MGE_ARMV8_2_FEATURE_FP16=OFF | ||||
| MGE_DISABLE_FLOAT16=OFF | MGE_DISABLE_FLOAT16=OFF | ||||
| ARCH=arm64-v8a | ARCH=arm64-v8a | ||||