Browse Source

Merge branch 'master' into fix_python_lib

pull/5860/head
佰阅 GitHub 1 year ago
parent
commit
fd2d4d5d46
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
16 changed files with 3669 additions and 1774 deletions
  1. +115
    -86
      CMakeLists.txt
  2. +5
    -1
      examples/CMakeLists.txt
  3. +358
    -136
      examples/yolov8.cpp
  4. +325
    -0
      examples/yolov8_cls.cpp
  5. +522
    -0
      examples/yolov8_obb.cpp
  6. +561
    -0
      examples/yolov8_pose.cpp
  7. +624
    -0
      examples/yolov8_seg.cpp
  8. +365
    -684
      src/layer/reduction.cpp
  9. +58
    -46
      tests/test_copyto_1.cpp
  10. +306
    -295
      tests/test_crop_1.cpp
  11. +47
    -44
      tests/test_expanddims.cpp
  12. +116
    -245
      tests/test_reduction.cpp
  13. +81
    -69
      tests/test_slice.cpp
  14. +57
    -45
      tests/test_slice_oom.cpp
  15. +47
    -44
      tests/test_squeeze.cpp
  16. +82
    -79
      tests/test_tile.cpp

+ 115
- 86
CMakeLists.txt View File

@@ -140,10 +140,11 @@ endif()

include(CheckCXXCompilerFlag)
set(CMAKE_TRY_COMPILE_CONFIGURATION release)
set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY)

# gnu inline assembly in clang msvc does not work actually
if(NOT (CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")))
check_cxx_source_compiles("int main() { int a = 0; asm volatile(\"\" : \"=r\"(a) : \"0\"(a) : \"memory\"); return 0; }" NCNN_COMPILER_SUPPORT_GNU_INLINE_ASM)
check_cxx_source_compiles("int test(int a) { asm volatile(\"\" : \"=r\"(a) : \"0\"(a) : \"memory\"); return a; }" NCNN_COMPILER_SUPPORT_GNU_INLINE_ASM)
if(NCNN_COMPILER_SUPPORT_GNU_INLINE_ASM)
option(NCNN_GNU_INLINE_ASM "optimize platform with gnu style inline assembly" ON)
else()
@@ -163,21 +164,21 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
endif()

if(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32)
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s, _a, _b; _s = vmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM_NEON)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat32x4_t test(float32x4_t s, float32x4_t a, float32x4_t b) { return vmlaq_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM_NEON)

if(NCNN_COMPILER_SUPPORT_ARM_NEON)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))
set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

unset(CMAKE_REQUIRED_FLAGS)
else()
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4)
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16)
endif()

unset(CMAKE_REQUIRED_FLAGS)
@@ -194,107 +195,107 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
if(CMAKE_SIZEOF_VOID_P EQUAL 8 OR NCNN_TARGET_ILP32)
if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
set(CMAKE_REQUIRED_FLAGS "/arch:armv8.0")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16x8_t _s, _a, _b; _s = vfmaq_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x8_t test(float16x8_t s, float16x8_t a, float16x8_t b) { return vfmaq_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vdotq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD)
check_cxx_source_compiles("#include <arm_neon.h>\nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vdotq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat32x4_t test(float32x4_t s, float16x8_t a, float16x8_t b) { return vfmlalq_low_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat32x4_t test(float32x4_t s, bfloat16x8_t a, bfloat16x8_t b) { return vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(s, a, b))); }" NCNN_COMPILER_SUPPORT_ARM84_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)
check_cxx_source_compiles("#include <arm_neon.h>\nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vmmlaq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat16_t _s, _a, _b; svbool_t bp; _s = svmla_f16_z(bp, _s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat16_t test(svfloat16_t s, svfloat16_t a, svfloat16_t b, svbool_t bp) { return svmla_f16_z(bp, s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svint16_t _s; svint8_t _a, _b; _s = svmlslb_s16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE2)
check_cxx_source_compiles("#include <arm_sve.h>\nsvint16_t test(svint16_t s, svint8_t a, svint8_t b) { return svmlslb_s16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE2)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat32_t _s; svbfloat16_t _a, _b; _s = svbfmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat32_t test(svfloat32_t s, svbfloat16_t a, svbfloat16_t b) { return svbfmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svint32_t _s; svint8_t _a, _b; _s = svmmla_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM)
check_cxx_source_compiles("#include <arm_sve.h>\nsvint32_t test(svint32_t s, svint8_t a, svint8_t b) { return svmmla_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat32_t _s, _a, _b; _s = svmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat32_t test(svfloat32_t s, svfloat32_t a, svfloat32_t b) { return svmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM)

unset(CMAKE_REQUIRED_FLAGS)
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
set(CMAKE_REQUIRED_FLAGS "/arch:armv8.0")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2 -march=armv8.2-a+fp16")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16x8_t _s, _a, _b; _s = vfmaq_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x8_t test(float16x8_t s, float16x8_t a, float16x8_t b) { return vfmaq_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2 -march=armv8.2-a+dotprod")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vdotq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD)
check_cxx_source_compiles("#include <arm_neon.h>\nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vdotq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2 -march=armv8.2-a+fp16fml")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat32x4_t test(float32x4_t s, float16x8_t a, float16x8_t b) { return vfmlalq_low_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4 -march=armv8.4-a+bf16")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat32x4_t test(float32x4_t s, bfloat16x8_t a, bfloat16x8_t b) { return vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(s, a, b))); }" NCNN_COMPILER_SUPPORT_ARM84_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4 -march=armv8.4-a+i8mm")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)
check_cxx_source_compiles("#include <arm_neon.h>\nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vmmlaq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat16_t _s, _a, _b; svbool_t bp; _s = svmla_f16_z(bp, _s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat16_t test(svfloat16_t s, svfloat16_t a, svfloat16_t b, svbool_t bp) { return svmla_f16_z(bp, s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve2")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svint16_t _s; svint8_t _a, _b; _s = svmlslb_s16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE2)
check_cxx_source_compiles("#include <arm_sve.h>\nsvint16_t test(svint16_t s, svint8_t a, svint8_t b) { return svmlslb_s16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE2)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve+bf16")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat32_t _s; svbfloat16_t _a, _b; _s = svbfmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat32_t test(svfloat32_t s, svbfloat16_t a, svbfloat16_t b) { return svbfmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve+i8mm")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svint32_t _s; svint8_t _a, _b; _s = svmmla_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM)
check_cxx_source_compiles("#include <arm_sve.h>\nsvint32_t test(svint32_t s, svint8_t a, svint8_t b) { return svmmla_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM)

set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve+f32mm")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat32_t _s, _a, _b; _s = svmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat32_t test(svfloat32_t s, svfloat32_t a, svfloat32_t b) { return svmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM)

unset(CMAKE_REQUIRED_FLAGS)
else()
set(CMAKE_REQUIRED_FLAGS "-march=armv8-a")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.2-a+fp16")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16x8_t _s, _a, _b; _s = vfmaq_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat16x8_t test(float16x8_t s, float16x8_t a, float16x8_t b) { return vfmaq_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.2-a+dotprod")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vdotq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD)
check_cxx_source_compiles("#include <arm_neon.h>\nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vdotq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.2-a+fp16fml")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat32x4_t test(float32x4_t s, float16x8_t a, float16x8_t b) { return vfmlalq_low_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+bf16")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16)
check_cxx_source_compiles("#include <arm_neon.h>\nfloat32x4_t test(float32x4_t s, bfloat16x8_t a, bfloat16x8_t b) { return vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(s, a, b))); }" NCNN_COMPILER_SUPPORT_ARM84_BF16)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+i8mm")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)
check_cxx_source_compiles("#include <arm_neon.h>\nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vmmlaq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat16_t _s, _a, _b; svbool_t bp; _s = svmla_f16_z(bp, _s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat16_t test(svfloat16_t s, svfloat16_t a, svfloat16_t b, svbool_t bp) { return svmla_f16_z(bp, s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve2")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svint16_t _s; svint8_t _a, _b; _s = svmlslb_s16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE2)
check_cxx_source_compiles("#include <arm_sve.h>\nsvint16_t test(svint16_t s, svint8_t a, svint8_t b) { return svmlslb_s16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE2)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve+bf16")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat32_t _s; svbfloat16_t _a, _b; _s = svbfmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat32_t test(svfloat32_t s, svbfloat16_t a, svbfloat16_t b) { return svbfmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve+i8mm")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svint32_t _s; svint8_t _a, _b; _s = svmmla_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM)
check_cxx_source_compiles("#include <arm_sve.h>\nsvint32_t test(svint32_t s, svint8_t a, svint8_t b) { return svmmla_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve+f32mm")
check_cxx_source_compiles("#include <arm_sve.h>\nint main() { svfloat32_t _s, _a, _b; _s = svmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM)
check_cxx_source_compiles("#include <arm_sve.h>\nsvfloat32_t test(svfloat32_t s, svfloat32_t a, svfloat32_t b) { return svmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM)

unset(CMAKE_REQUIRED_FLAGS)
endif()
@@ -380,7 +381,7 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(mips)")
check_cxx_compiler_flag("-mmsa" NCNN_COMPILER_SUPPORT_MIPS_MSA)

set(CMAKE_REQUIRED_FLAGS "-mloongson-mmi -I${CMAKE_CURRENT_SOURCE_DIR}/src/layer/mips")
check_cxx_source_compiles("#include \"loongson_mmi.h\"\nint main() { int16x4_t _a, _b; int32x2_t _s = __mmi_pmaddhw(_a, _b); return 0; }" NCNN_COMPILER_SUPPORT_LOONGSON_MMI)
check_cxx_source_compiles("#include \"loongson_mmi.h\"\nint32x2_t test(int16x4_t a, int16x4_t b) { return __mmi_pmaddhw(a, b); }" NCNN_COMPILER_SUPPORT_LOONGSON_MMI)

unset(CMAKE_REQUIRED_FLAGS)

@@ -398,10 +399,10 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(loongarch64|loongarch32)")
set(NCNN_TARGET_ARCH loongarch)

set(CMAKE_REQUIRED_FLAGS "-mlsx")
check_cxx_source_compiles("#include <lsxintrin.h>\nint main() { __m128 _s, _a, _b, _c; _s = __lsx_vfmadd_s(_a, _b, _c); return 0; }" NCNN_COMPILER_SUPPORT_LOONGARCH_LSX)
check_cxx_source_compiles("#include <lsxintrin.h>\n__m128 test(__m128 a, __m128 b, __m128 c) { return __lsx_vfmadd_s(a, b, c); }" NCNN_COMPILER_SUPPORT_LOONGARCH_LSX)

set(CMAKE_REQUIRED_FLAGS "-mlasx")
check_cxx_source_compiles("#include <lasxintrin.h>\nint main() { __m256 _s, _a, _b, _c; _s = __lasx_xvfmadd_s(_a, _b, _c); return 0; }" NCNN_COMPILER_SUPPORT_LOONGARCH_LASX)
check_cxx_source_compiles("#include <lasxintrin.h>\n__m256 test(__m256 a, __m256 b, __m256 c) { return __lasx_xvfmadd_s(a, b, c); }" NCNN_COMPILER_SUPPORT_LOONGARCH_LASX)

unset(CMAKE_REQUIRED_FLAGS)

@@ -421,16 +422,16 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)")

if(CMAKE_SIZEOF_VOID_P EQUAL 8)
set(CMAKE_REQUIRED_FLAGS "-march=rv64gcv")
check_cxx_source_compiles("#include <riscv_vector.h>\nint main() { vfloat32m8_t _s, _w; float _v; size_t vl; _s = __riscv_vfmacc_vf_f32m8(_s, _v, _w, vl); vfloat32m1_t _x; vfloat32m1x2_t _xx = __riscv_vcreate_v_f32m1x2(_x, _x); return 0; }" NCNN_COMPILER_SUPPORT_RISCV_V)
check_cxx_source_compiles("#include <riscv_vector.h>\nvfloat32m8_t test(vfloat32m8_t s, vfloat32m8_t w, float v, size_t vl) { return __riscv_vfmacc_vf_f32m8(s, v, w, vl); }\nvfloat32m1x2_t test2(vfloat32m1_t x) { return __riscv_vcreate_v_f32m1x2(x, x); }" NCNN_COMPILER_SUPPORT_RISCV_V)

set(CMAKE_REQUIRED_FLAGS "-march=rv64gc_zfh -D__fp16=_Float16")
check_cxx_source_compiles("int main() { __fp16 s, v; s = v * v; return 0; }" NCNN_COMPILER_SUPPORT_RISCV_ZFH)
check_cxx_source_compiles("__fp16 test(__fp16 a) { return a * a; }" NCNN_COMPILER_SUPPORT_RISCV_ZFH)

set(CMAKE_REQUIRED_FLAGS "-march=rv64gcv_zfh_zvfh -D__fp16=_Float16")
check_cxx_source_compiles("#include <riscv_vector.h>\nint main() { vfloat16m8_t _s, _w; __fp16 _v; size_t vl; _s = __riscv_vfmacc_vf_f16m8(_s, _v, _w, vl); return 0; }" NCNN_COMPILER_SUPPORT_RISCV_ZVFH)
check_cxx_source_compiles("#include <riscv_vector.h>\nvfloat16m8_t test(vfloat16m8_t s, vfloat16m8_t w, __fp16 v, size_t vl) { return __riscv_vfmacc_vf_f16m8(s, v, w, vl); }\nvfloat16m1x2_t test2(vfloat16m1_t x){ return __riscv_vcreate_v_f16m1x2(x, x); }" NCNN_COMPILER_SUPPORT_RISCV_ZVFH)

set(CMAKE_REQUIRED_FLAGS "-march=rv64gc_zfh_xtheadvector -D__fp16=_Float16")
check_cxx_source_compiles("#include <riscv_vector.h>\nint main() { vfloat16m8_t _s, _w; __fp16 _v; size_t vl; _s = __riscv_vfmacc_vf_f16m8(_s, _v, _w, vl); vfloat32m1_t _x; vfloat32m1x2_t _xx = __riscv_vcreate_v_f32m1x2(_x, _x); return 0; }" NCNN_COMPILER_SUPPORT_RISCV_XTHEADVECTOR)
check_cxx_source_compiles("#include <riscv_vector.h>\nvfloat16m8_t test(vfloat16m8_t s, vfloat16m8_t w, __fp16 v, size_t vl) { return __riscv_vfmacc_vf_f16m8(s, v, w, vl); }\nvfloat16m1x2_t test2(vfloat16m1_t x){ return __riscv_vcreate_v_f16m1x2(x, x); }" NCNN_COMPILER_SUPPORT_RISCV_XTHEADVECTOR)

unset(CMAKE_REQUIRED_FLAGS)

@@ -467,11 +468,11 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc|ppc)")
set(NCNN_TARGET_ARCH x86)

set(CMAKE_REQUIRED_FLAGS "-DNO_WARN_X86_INTRINSICS -D__SSE2__")
check_cxx_source_compiles("#include <emmintrin.h>\nint main() { return 0; }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE2)
check_cxx_source_compiles("#include <emmintrin.h>\n__m128i test(__m128i a, __m128i b) { return _mm_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE2)
unset(CMAKE_REQUIRED_FLAGS)

set(CMAKE_REQUIRED_FLAGS "-DNO_WARN_X86_INTRINSICS -D__SSE4_1__")
check_cxx_source_compiles("#include <smmintrin.h>\nint main() { __m128i _v, _a, _b; _v = _mm_packus_epi32(_a, _b); return 0; }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE41)
check_cxx_source_compiles("#include <smmintrin.h>\n__m128i test(__m128i a, __m128i b) { return _mm_packus_epi32(a, b); }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE41)
unset(CMAKE_REQUIRED_FLAGS)

if(NCNN_COMPILER_SUPPORT_PPC64LE_SSE2)
@@ -501,105 +502,130 @@ else()
option(NCNN_SSE2 "optimize x86 platform with sse2 extension" ON)

if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_AVX)
check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_FMA)
check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_XOP)
check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_F16C)
check_cxx_compiler_flag("/arch:AVX2" NCNN_COMPILER_SUPPORT_X86_AVX2)
check_cxx_compiler_flag("/arch:AVX512" NCNN_COMPILER_SUPPORT_X86_AVX512)
set(CMAKE_REQUIRED_FLAGS "/arch:AVX")
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX")
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m256 s, __m256 a, __m256 b) { return _mm256_fmadd_ps(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_FMA)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX")
check_cxx_source_compiles("#include <immintrin.h>\n#include <ammintrin.h>\n__m128i test(__m128i s, __m128i a, __m128i b) { return _mm_maddd_epi16(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_XOP)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX")
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m128i a) { return _mm256_cvtph_ps(a); }" NCNN_COMPILER_SUPPORT_X86_F16C)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i a, __m256i b) { return _mm256_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX2)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\n__m512i test(__m512i a, __m512i b) { return _mm512_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwssd_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpbssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwsud_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)
check_cxx_source_compiles("#include <immintrin.h>\n__m128bh test(__m256 a) { return _mm256_cvtneps_avx_pbh(a); }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\n__m512i test(__m512i s, __m512i a, __m512i b) { return _mm512_dpwssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\n__m256bh test(__m256bh s, __m512bh a, __m512bh b) { return _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(s), a, b)); }\n__m512i test2(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
check_cxx_source_compiles("#include <immintrin.h>\n__m512h test(__m512h s, __m512h a, __m512h b) { return _mm512_fmadd_ph(s, a, b); }\n__m512 test2(__m512 a) { return _mm512_cvtxph_ps(_mm512_cvtxps_ph(a)); }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)

unset(CMAKE_REQUIRED_FLAGS)
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")
check_cxx_compiler_flag("-mrecip=none" NCNN_COMPILER_SUPPORT_X86_RECIP_NONE)

check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_AVX)
set(CMAKE_REQUIRED_FLAGS "/arch:AVX")
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX -mfma -mf16c")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _s, _a, _b; _s = _mm256_fmadd_ps(_a, _b, _s); return 0; }" NCNN_COMPILER_SUPPORT_X86_FMA)
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m256 s, __m256 a, __m256 b) { return _mm256_fmadd_ps(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_FMA)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX -mxop")
check_cxx_source_compiles("#include <x86intrin.h>\nint main() { __m128 _s, _a, _b; _s = _mm_maddd_epi16(_a, _b, _s); return 0; }" NCNN_COMPILER_SUPPORT_X86_XOP)
check_cxx_source_compiles("#include <x86intrin.h>\n__m128i test(__m128i s, __m128i a, __m128i b) { return _mm_maddd_epi16(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_XOP)

check_cxx_compiler_flag("/arch:AVX -mf16c" NCNN_COMPILER_SUPPORT_X86_F16C)
check_cxx_compiler_flag("/arch:AVX2 -mfma -mf16c" NCNN_COMPILER_SUPPORT_X86_AVX2)
check_cxx_compiler_flag("/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl" NCNN_COMPILER_SUPPORT_X86_AVX512)
set(CMAKE_REQUIRED_FLAGS "/arch:AVX -mf16c")
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m128i a) { return _mm256_cvtph_ps(a); }" NCNN_COMPILER_SUPPORT_X86_F16C)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c")
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i a, __m256i b) { return _mm256_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX2)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl")
check_cxx_source_compiles("#include <immintrin.h>\n__m512i test(__m512i a, __m512i b) { return _mm512_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwssd_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpbssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwsud_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxneconvert")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)
check_cxx_source_compiles("#include <immintrin.h>\n__m128bh test(__m256 a) { return _mm256_cvtneps_avx_pbh(a); }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\n__m512i test(__m512i s, __m512i a, __m512i b) { return _mm512_dpwssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\n__m256bh test(__m256bh s, __m512bh a, __m512bh b) { return _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(s), a, b)); }\n__m512i test2(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
check_cxx_source_compiles("#include <immintrin.h>\n__m512h test(__m512h s, __m512h a, __m512h b) { return _mm512_fmadd_ph(s, a, b); }\n__m512 test2(__m512 a) { return _mm512_cvtxph_ps(_mm512_cvtxps_ph(a)); }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)

unset(CMAKE_REQUIRED_FLAGS)
else()
check_cxx_compiler_flag("-mrecip=none" NCNN_COMPILER_SUPPORT_X86_RECIP_NONE)

check_cxx_compiler_flag("-mavx" NCNN_COMPILER_SUPPORT_X86_AVX)
set(CMAKE_REQUIRED_FLAGS "-mavx")
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _s, _a, _b; _s = _mm256_fmadd_ps(_a, _b, _s); return 0; }" NCNN_COMPILER_SUPPORT_X86_FMA)
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m256 s, __m256 a, __m256 b) { return _mm256_fmadd_ps(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_FMA)

set(CMAKE_REQUIRED_FLAGS "-mfma -mxop")
check_cxx_source_compiles("#include <x86intrin.h>\n__m128i test(__m128i s, __m128i a, __m128i b) { return _mm_maddd_epi16(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_XOP)

check_cxx_compiler_flag("-mxop" NCNN_COMPILER_SUPPORT_X86_XOP)
check_cxx_compiler_flag("-mf16c" NCNN_COMPILER_SUPPORT_X86_F16C)
check_cxx_compiler_flag("-mfma -mf16c -mavx2" NCNN_COMPILER_SUPPORT_X86_AVX2)
check_cxx_compiler_flag("-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl" NCNN_COMPILER_SUPPORT_X86_AVX512)
set(CMAKE_REQUIRED_FLAGS "-mf16c")
check_cxx_source_compiles("#include <immintrin.h>\n__m256 test(__m128i a) { return _mm256_cvtph_ps(a); }" NCNN_COMPILER_SUPPORT_X86_F16C)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2")
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i a, __m256i b) { return _mm256_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX2)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl")
check_cxx_source_compiles("#include <immintrin.h>\n__m512i test(__m512i a, __m512i b) { return _mm512_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwssd_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint8")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpbssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)
check_cxx_source_compiles("#include <immintrin.h>\n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwsud_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxneconvert")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)
check_cxx_source_compiles("#include <immintrin.h>\n__m128bh test(__m256 a) { return _mm256_cvtneps_avx_pbh(a); }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)
check_cxx_source_compiles("#include <immintrin.h>\n__m512i test(__m512i s, __m512i a, __m512i b) { return _mm512_dpwssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)
check_cxx_source_compiles("#include <immintrin.h>\n__m256bh test(__m256bh s, __m512bh a, __m512bh b) { return _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(s), a, b)); }\n__m512i test2(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16)

set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16")
check_cxx_source_compiles("#include <immintrin.h>\nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)
check_cxx_source_compiles("#include <immintrin.h>\n__m512h test(__m512h s, __m512h a, __m512h b) { return _mm512_fmadd_ph(s, a, b); }\n__m512 test2(__m512 a) { return _mm512_cvtxph_ps(_mm512_cvtxps_ph(a)); }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16)

unset(CMAKE_REQUIRED_FLAGS)
endif()
@@ -695,6 +721,9 @@ else()
endif()
endif()

unset(CMAKE_TRY_COMPILE_CONFIGURATION)
unset(CMAKE_TRY_COMPILE_TARGET_TYPE)

if(NCNN_TARGET_ILP32)
message(STATUS "Target arch: ${NCNN_TARGET_ARCH} 64bit ilp32")
elseif(CMAKE_SIZEOF_VOID_P EQUAL 8)


+ 5
- 1
examples/CMakeLists.txt View File

@@ -52,6 +52,10 @@ if(NCNN_PIXEL)
ncnn_add_example(yolov5_pnnx)
ncnn_add_example(yolov7_pnnx)
ncnn_add_example(yolov7)
ncnn_add_example(yolov8)
ncnn_add_example(yolov8_seg)
ncnn_add_example(yolov8_pose)
ncnn_add_example(yolov8_cls)
ncnn_add_example(yolox)
ncnn_add_example(mobilenetv2ssdlite)
ncnn_add_example(mobilenetssd)
@@ -67,9 +71,9 @@ if(NCNN_PIXEL)
ncnn_add_example(scrfd_crowdhuman)
if(OpenCV_FOUND)
ncnn_add_example(yolov4)
ncnn_add_example(yolov8_obb)
ncnn_add_example(rvm)
ncnn_add_example(p2pnet)
ncnn_add_example(yolov8)
endif()
else()
message(WARNING "OpenCV not found and NCNN_SIMPLEOCV disabled, examples won't be built")


+ 358
- 136
examples/yolov8.cpp View File

@@ -2,8 +2,6 @@
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Copyright (C) 2024 whyb(https://github.com/whyb). All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
@@ -14,49 +12,61 @@
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// ReadMe
// Convert yolov8 model to ncnn model workflow:
//
// step 1:
// If you don't want to train the model yourself. You should go to the ultralytics website download the pretrained model file.
// original pretrained model from https://docs.ultralytics.com/models/yolov8/#supported-tasks-and-modes
// 1. install
// pip3 install -U ultralytics pnnx ncnn
// 2. export yolov8 torchscript
// yolo export model=yolov8n.pt format=torchscript
// 3. convert torchscript with static shape
// pnnx yolov8n.torchscript
// 4. modify yolov8n_pnnx.py for dynamic shape inference
// A. modify reshape to support dynamic image sizes
// B. permute tensor before concat and adjust concat axis
// C. drop post-process part
// before:
// v_165 = v_142.view(1, 144, 6400)
// v_166 = v_153.view(1, 144, 1600)
// v_167 = v_164.view(1, 144, 400)
// v_168 = torch.cat((v_165, v_166, v_167), dim=2)
// ...
// after:
// v_165 = v_142.view(1, 144, -1).transpose(1, 2)
// v_166 = v_153.view(1, 144, -1).transpose(1, 2)
// v_167 = v_164.view(1, 144, -1).transpose(1, 2)
// v_168 = torch.cat((v_165, v_166, v_167), dim=1)
// return v_168
// 5. re-export yolov8 torchscript
// python3 -c 'import yolov8n_pnnx; yolov8n_pnnx.export_torchscript()'
// 6. convert new torchscript with dynamic shape
// pnnx yolov8n_pnnx.py.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320]
// 7. now you get ncnn model files
// mv yolov8n_pnnx.py.ncnn.param yolov8n.ncnn.param
// mv yolov8n_pnnx.py.ncnn.bin yolov8n.ncnn.bin

// the out blob would be a 2-dim tensor with w=144 h=8400
//
// step 2:
// run this command.
// conda create --name yolov8 python=3.11
// conda activate yolov8
// pip install ultralytics onnx numpy protobuf
// | bbox-reg 16 x 4 | per-class scores(80) |
// +-----+-----+-----+-----+----------------------+
// | dx0 | dy0 | dx1 | dy1 |0.1 0.0 0.0 0.5 ......|
// all /| | | | | . |
// boxes | .. | .. | .. | .. |0.0 0.9 0.0 0.0 ......|
// (8400)| | | | | . |
// \| | | | | . |
// +-----+-----+-----+-----+----------------------+
//
// step 3:
// save source code file(export_model_to_ncnn.py):
// from ultralytics import YOLO
// detection_models = [
// ["./Detection-pt/yolov8n.pt", "./Detection-pt/"],
// ["./Detection-pt/yolov8s.pt", "./Detection-pt/"],
// ["./Detection-pt/yolov8m.pt", "./Detection-pt/"],
// ["./Detection-pt/yolov8l.pt", "./Detection-pt/"],
// ["./Detection-pt/yolov8x.pt", "./Detection-pt/"]
// ]
// for model_dict in detection_models:
// model = YOLO(model_dict[0]) # load an official pretrained weight model
// model.export(format="ncnn", dynamic=True, save_dir=model_dict[1], simplify=True)
//
// step 4:
// run command: python export_model_to_ncnn.py

#include <memory>
#include <vector>
#include <algorithm>
#include "layer.h"
#include "net.h"

#include <opencv2/opencv.hpp>
#if defined(USE_NCNN_SIMPLEOCV)
#include "simpleocv.h"
#else
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#endif
#include <float.h>
#include <stdio.h>

#define MAX_STRIDE 32
#include <vector>

struct Object
{
@@ -95,13 +105,13 @@ static void qsort_descent_inplace(std::vector<Object>& objects, int left, int ri
}
}

#pragma omp parallel sections
// #pragma omp parallel sections
{
#pragma omp section
// #pragma omp section
{
if (left < j) qsort_descent_inplace(objects, left, j);
}
#pragma omp section
// #pragma omp section
{
if (i < right) qsort_descent_inplace(objects, i, right);
}
@@ -116,26 +126,26 @@ static void qsort_descent_inplace(std::vector<Object>& objects)
qsort_descent_inplace(objects, 0, objects.size() - 1);
}

static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
picked.clear();

const int n = faceobjects.size();
const int n = objects.size();

std::vector<float> areas(n);
for (int i = 0; i < n; i++)
{
areas[i] = faceobjects[i].rect.area();
areas[i] = objects[i].rect.area();
}

for (int i = 0; i < n; i++)
{
const Object& a = faceobjects[i];
const Object& a = objects[i];

int keep = 1;
for (int j = 0; j < (int)picked.size(); j++)
{
const Object& b = faceobjects[picked[j]];
const Object& b = objects[picked[j]];

if (!agnostic && a.label != b.label)
continue;
@@ -155,66 +165,146 @@ static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vecto

static inline float sigmoid(float x)
{
return static_cast<float>(1.f / (1.f + exp(-x)));
return 1.0f / (1.0f + expf(-x));
}

static inline float clampf(float d, float min, float max)
static void generate_proposals(const ncnn::Mat& pred, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const float t = d < min ? min : d;
return t > max ? max : t;
}
const int w = in_pad.w;
const int h = in_pad.h;

static void parse_yolov8_detections(
float* inputs, float confidence_threshold,
int num_channels, int num_anchors, int num_labels,
int infer_img_width, int infer_img_height,
std::vector<Object>& objects)
{
std::vector<Object> detections;
cv::Mat output = cv::Mat((int)num_channels, (int)num_anchors, CV_32F, inputs).t();
const int num_grid_x = w / stride;
const int num_grid_y = h / stride;

for (int i = 0; i < num_anchors; i++)
const int reg_max_1 = 16;
const int num_class = pred.w - reg_max_1 * 4; // number of classes. 80 for COCO

for (int y = 0; y < num_grid_y; y++)
{
const float* row_ptr = output.row(i).ptr<float>();
const float* bboxes_ptr = row_ptr;
const float* scores_ptr = row_ptr + 4;
const float* max_s_ptr = std::max_element(scores_ptr, scores_ptr + num_labels);
float score = *max_s_ptr;
if (score > confidence_threshold)
for (int x = 0; x < num_grid_x; x++)
{
float x = *bboxes_ptr++;
float y = *bboxes_ptr++;
float w = *bboxes_ptr++;
float h = *bboxes_ptr;

float x0 = clampf((x - 0.5f * w), 0.f, (float)infer_img_width);
float y0 = clampf((y - 0.5f * h), 0.f, (float)infer_img_height);
float x1 = clampf((x + 0.5f * w), 0.f, (float)infer_img_width);
float y1 = clampf((y + 0.5f * h), 0.f, (float)infer_img_height);

cv::Rect_<float> bbox;
bbox.x = x0;
bbox.y = y0;
bbox.width = x1 - x0;
bbox.height = y1 - y0;
Object object;
object.label = max_s_ptr - scores_ptr;
object.prob = score;
object.rect = bbox;
detections.push_back(object);
const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1);

// find label with max score
int label = -1;
float score = -FLT_MAX;
{
const ncnn::Mat pred_score = pred_grid.range(reg_max_1 * 4, num_class);

for (int k = 0; k < num_class; k++)
{
float s = pred_score[k];
if (s > score)
{
label = k;
score = s;
}
}

score = sigmoid(score);
}

if (score >= prob_threshold)
{
ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4);

{
ncnn::Layer* softmax = ncnn::create_layer("Softmax");

ncnn::ParamDict pd;
pd.set(0, 1); // axis
pd.set(1, 1);
softmax->load_param(pd);

ncnn::Option opt;
opt.num_threads = 1;
opt.use_packing_layout = false;

softmax->create_pipeline(opt);

softmax->forward_inplace(pred_bbox, opt);

softmax->destroy_pipeline(opt);

delete softmax;
}

float pred_ltrb[4];
for (int k = 0; k < 4; k++)
{
float dis = 0.f;
const float* dis_after_sm = pred_bbox.row(k);
for (int l = 0; l < reg_max_1; l++)
{
dis += l * dis_after_sm[l];
}

pred_ltrb[k] = dis * stride;
}

float pb_cx = (x + 0.5f) * stride;
float pb_cy = (y + 0.5f) * stride;

float x0 = pb_cx - pred_ltrb[0];
float y0 = pb_cy - pred_ltrb[1];
float x1 = pb_cx + pred_ltrb[2];
float y1 = pb_cy + pred_ltrb[3];

Object obj;
obj.rect.x = x0;
obj.rect.y = y0;
obj.rect.width = x1 - x0;
obj.rect.height = y1 - y0;
obj.label = label;
obj.prob = score;

objects.push_back(obj);
}
}
}
objects = detections;
}

static void generate_proposals(const ncnn::Mat& pred, const std::vector<int>& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const int w = in_pad.w;
const int h = in_pad.h;

int pred_row_offset = 0;
for (size_t i = 0; i < strides.size(); i++)
{
const int stride = strides[i];

const int num_grid_x = w / stride;
const int num_grid_y = h / stride;
const int num_grid = num_grid_x * num_grid_y;

generate_proposals(pred.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects);
pred_row_offset += num_grid;
}
}

static int detect_yolov8(const cv::Mat& bgr, std::vector<Object>& objects)
{
ncnn::Net yolov8;

yolov8.opt.use_vulkan_compute = true; // if you want detect in hardware, then enable it

yolov8.load_param("yolov8n.param");
yolov8.load_model("yolov8n.bin");
yolov8.opt.use_vulkan_compute = true;
// yolov8.opt.use_bf16_storage = true;

// https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets
yolov8.load_param("yolov8n.ncnn.param");
yolov8.load_model("yolov8n.ncnn.bin");
// yolov8.load_param("yolov8s.ncnn.param");
// yolov8.load_model("yolov8s.ncnn.bin");
// yolov8.load_param("yolov8m.ncnn.param");
// yolov8.load_model("yolov8m.ncnn.bin");

// if you use oiv7 models, you shall call draw_objects_oiv() instead
// yolov8.load_param("yolov8n_oiv7.ncnn.param");
// yolov8.load_model("yolov8n_oiv7.ncnn.bin");
// yolov8.load_param("yolov8s_oiv7.ncnn.param");
// yolov8.load_model("yolov8s_oiv7.ncnn.bin");
// yolov8.load_param("yolov8m_oiv7.ncnn.param");
// yolov8.load_model("yolov8m_oiv7.ncnn.bin");

const int target_size = 640;
const float prob_threshold = 0.25f;
@@ -223,7 +313,14 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector<Object>& objects)
int img_w = bgr.cols;
int img_h = bgr.rows;

// letterbox pad to multiple of MAX_STRIDE
// ultralytics/cfg/models/v8/yolov8.yaml
std::vector<int> strides(3);
strides[0] = 8;
strides[1] = 16;
strides[2] = 32;
const int max_stride = 32;

// letterbox pad to multiple of max_stride
int w = img_w;
int h = img_h;
float scale = 1.f;
@@ -242,8 +339,9 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector<Object>& objects)

ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);

int wpad = (target_size + MAX_STRIDE - 1) / MAX_STRIDE * MAX_STRIDE - w;
int hpad = (target_size + MAX_STRIDE - 1) / MAX_STRIDE * MAX_STRIDE - h;
// letterbox pad to target_size rectangle
int wpad = (w + max_stride - 1) / max_stride * max_stride - w;
int hpad = (h + max_stride - 1) / max_stride * max_stride - h;
ncnn::Mat in_pad;
ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);

@@ -254,22 +352,11 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector<Object>& objects)

ex.input("in0", in_pad);

std::vector<Object> proposals;
ncnn::Mat out;
ex.extract("out0", out);

// stride 32
{
ncnn::Mat out;
ex.extract("out0", out);

std::vector<Object> objects32;
const int num_labels = 80; // COCO has detect 80 object labels.
parse_yolov8_detections(
(float*)out.data, prob_threshold,
out.h, out.w, num_labels,
in_pad.w, in_pad.h,
objects32);
proposals.insert(proposals.end(), objects32.begin(), objects32.end());
}
std::vector<Object> proposals;
generate_proposals(out, strides, in_pad, prob_threshold, proposals);

// sort all proposals by score from highest to lowest
qsort_descent_inplace(proposals);
@@ -306,7 +393,7 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector<Object>& objects)
return 0;
}

static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
static void draw_objects_coco(const cv::Mat& bgr, const std::vector<Object>& objects)
{
static const char* class_names[] = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
@@ -320,45 +407,179 @@ static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
"hair drier", "toothbrush"
};

static const unsigned char colors[19][3] = {
{54, 67, 244},
{99, 30, 233},
{176, 39, 156},
{183, 58, 103},
{181, 81, 63},
{243, 150, 33},
{244, 169, 3},
{212, 188, 0},
{136, 150, 0},
{80, 175, 76},
{74, 195, 139},
{57, 220, 205},
{59, 235, 255},
{7, 193, 255},
{0, 152, 255},
{34, 87, 255},
{72, 85, 121},
{158, 158, 158},
{139, 125, 96}
static cv::Scalar colors[] = {
cv::Scalar(244, 67, 54),
cv::Scalar(233, 30, 99),
cv::Scalar(156, 39, 176),
cv::Scalar(103, 58, 183),
cv::Scalar(63, 81, 181),
cv::Scalar(33, 150, 243),
cv::Scalar(3, 169, 244),
cv::Scalar(0, 188, 212),
cv::Scalar(0, 150, 136),
cv::Scalar(76, 175, 80),
cv::Scalar(139, 195, 74),
cv::Scalar(205, 220, 57),
cv::Scalar(255, 235, 59),
cv::Scalar(255, 193, 7),
cv::Scalar(255, 152, 0),
cv::Scalar(255, 87, 34),
cv::Scalar(121, 85, 72),
cv::Scalar(158, 158, 158),
cv::Scalar(96, 125, 139)
};

int color_index = 0;

cv::Mat image = bgr.clone();

for (size_t i = 0; i < objects.size(); i++)
{
const Object& obj = objects[i];

const unsigned char* color = colors[color_index % 19];
color_index++;
const cv::Scalar& color = colors[i % 19];

fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);

cv::rectangle(image, obj.rect, color);

char text[256];
sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);

int baseLine = 0;
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

int x = obj.rect.x;
int y = obj.rect.y - label_size.height - baseLine;
if (y < 0)
y = 0;
if (x + label_size.width > image.cols)
x = image.cols - label_size.width;

cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
cv::Scalar(255, 255, 255), -1);

cv::putText(image, text, cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
}

cv::imshow("image", image);
cv::waitKey(0);
}

static void draw_objects_oiv(const cv::Mat& bgr, const std::vector<Object>& objects)
{
static const char* class_names[] = {
"Accordion", "Adhesive tape", "Aircraft", "Airplane", "Alarm clock", "Alpaca", "Ambulance", "Animal",
"Ant", "Antelope", "Apple", "Armadillo", "Artichoke", "Auto part", "Axe", "Backpack", "Bagel",
"Baked goods", "Balance beam", "Ball", "Balloon", "Banana", "Band-aid", "Banjo", "Barge", "Barrel",
"Baseball bat", "Baseball glove", "Bat (Animal)", "Bathroom accessory", "Bathroom cabinet", "Bathtub",
"Beaker", "Bear", "Bed", "Bee", "Beehive", "Beer", "Beetle", "Bell pepper", "Belt", "Bench", "Bicycle",
"Bicycle helmet", "Bicycle wheel", "Bidet", "Billboard", "Billiard table", "Binoculars", "Bird",
"Blender", "Blue jay", "Boat", "Bomb", "Book", "Bookcase", "Boot", "Bottle", "Bottle opener",
"Bow and arrow", "Bowl", "Bowling equipment", "Box", "Boy", "Brassiere", "Bread", "Briefcase",
"Broccoli", "Bronze sculpture", "Brown bear", "Building", "Bull", "Burrito", "Bus", "Bust", "Butterfly",
"Cabbage", "Cabinetry", "Cake", "Cake stand", "Calculator", "Camel", "Camera", "Can opener", "Canary",
"Candle", "Candy", "Cannon", "Canoe", "Cantaloupe", "Car", "Carnivore", "Carrot", "Cart", "Cassette deck",
"Castle", "Cat", "Cat furniture", "Caterpillar", "Cattle", "Ceiling fan", "Cello", "Centipede",
"Chainsaw", "Chair", "Cheese", "Cheetah", "Chest of drawers", "Chicken", "Chime", "Chisel", "Chopsticks",
"Christmas tree", "Clock", "Closet", "Clothing", "Coat", "Cocktail", "Cocktail shaker", "Coconut",
"Coffee", "Coffee cup", "Coffee table", "Coffeemaker", "Coin", "Common fig", "Common sunflower",
"Computer keyboard", "Computer monitor", "Computer mouse", "Container", "Convenience store", "Cookie",
"Cooking spray", "Corded phone", "Cosmetics", "Couch", "Countertop", "Cowboy hat", "Crab", "Cream",
"Cricket ball", "Crocodile", "Croissant", "Crown", "Crutch", "Cucumber", "Cupboard", "Curtain",
"Cutting board", "Dagger", "Dairy Product", "Deer", "Desk", "Dessert", "Diaper", "Dice", "Digital clock",
"Dinosaur", "Dishwasher", "Dog", "Dog bed", "Doll", "Dolphin", "Door", "Door handle", "Doughnut",
"Dragonfly", "Drawer", "Dress", "Drill (Tool)", "Drink", "Drinking straw", "Drum", "Duck", "Dumbbell",
"Eagle", "Earrings", "Egg (Food)", "Elephant", "Envelope", "Eraser", "Face powder", "Facial tissue holder",
"Falcon", "Fashion accessory", "Fast food", "Fax", "Fedora", "Filing cabinet", "Fire hydrant",
"Fireplace", "Fish", "Flag", "Flashlight", "Flower", "Flowerpot", "Flute", "Flying disc", "Food",
"Food processor", "Football", "Football helmet", "Footwear", "Fork", "Fountain", "Fox", "French fries",
"French horn", "Frog", "Fruit", "Frying pan", "Furniture", "Garden Asparagus", "Gas stove", "Giraffe",
"Girl", "Glasses", "Glove", "Goat", "Goggles", "Goldfish", "Golf ball", "Golf cart", "Gondola",
"Goose", "Grape", "Grapefruit", "Grinder", "Guacamole", "Guitar", "Hair dryer", "Hair spray", "Hamburger",
"Hammer", "Hamster", "Hand dryer", "Handbag", "Handgun", "Harbor seal", "Harmonica", "Harp",
"Harpsichord", "Hat", "Headphones", "Heater", "Hedgehog", "Helicopter", "Helmet", "High heels",
"Hiking equipment", "Hippopotamus", "Home appliance", "Honeycomb", "Horizontal bar", "Horse", "Hot dog",
"House", "Houseplant", "Human arm", "Human beard", "Human body", "Human ear", "Human eye", "Human face",
"Human foot", "Human hair", "Human hand", "Human head", "Human leg", "Human mouth", "Human nose",
"Humidifier", "Ice cream", "Indoor rower", "Infant bed", "Insect", "Invertebrate", "Ipod", "Isopod",
"Jacket", "Jacuzzi", "Jaguar (Animal)", "Jeans", "Jellyfish", "Jet ski", "Jug", "Juice", "Kangaroo",
"Kettle", "Kitchen & dining room table", "Kitchen appliance", "Kitchen knife", "Kitchen utensil",
"Kitchenware", "Kite", "Knife", "Koala", "Ladder", "Ladle", "Ladybug", "Lamp", "Land vehicle",
"Lantern", "Laptop", "Lavender (Plant)", "Lemon", "Leopard", "Light bulb", "Light switch", "Lighthouse",
"Lily", "Limousine", "Lion", "Lipstick", "Lizard", "Lobster", "Loveseat", "Luggage and bags", "Lynx",
"Magpie", "Mammal", "Man", "Mango", "Maple", "Maracas", "Marine invertebrates", "Marine mammal",
"Measuring cup", "Mechanical fan", "Medical equipment", "Microphone", "Microwave oven", "Milk",
"Miniskirt", "Mirror", "Missile", "Mixer", "Mixing bowl", "Mobile phone", "Monkey", "Moths and butterflies",
"Motorcycle", "Mouse", "Muffin", "Mug", "Mule", "Mushroom", "Musical instrument", "Musical keyboard",
"Nail (Construction)", "Necklace", "Nightstand", "Oboe", "Office building", "Office supplies", "Orange",
"Organ (Musical Instrument)", "Ostrich", "Otter", "Oven", "Owl", "Oyster", "Paddle", "Palm tree",
"Pancake", "Panda", "Paper cutter", "Paper towel", "Parachute", "Parking meter", "Parrot", "Pasta",
"Pastry", "Peach", "Pear", "Pen", "Pencil case", "Pencil sharpener", "Penguin", "Perfume", "Person",
"Personal care", "Personal flotation device", "Piano", "Picnic basket", "Picture frame", "Pig",
"Pillow", "Pineapple", "Pitcher (Container)", "Pizza", "Pizza cutter", "Plant", "Plastic bag", "Plate",
"Platter", "Plumbing fixture", "Polar bear", "Pomegranate", "Popcorn", "Porch", "Porcupine", "Poster",
"Potato", "Power plugs and sockets", "Pressure cooker", "Pretzel", "Printer", "Pumpkin", "Punching bag",
"Rabbit", "Raccoon", "Racket", "Radish", "Ratchet (Device)", "Raven", "Rays and skates", "Red panda",
"Refrigerator", "Remote control", "Reptile", "Rhinoceros", "Rifle", "Ring binder", "Rocket",
"Roller skates", "Rose", "Rugby ball", "Ruler", "Salad", "Salt and pepper shakers", "Sandal",
"Sandwich", "Saucer", "Saxophone", "Scale", "Scarf", "Scissors", "Scoreboard", "Scorpion",
"Screwdriver", "Sculpture", "Sea lion", "Sea turtle", "Seafood", "Seahorse", "Seat belt", "Segway",
"Serving tray", "Sewing machine", "Shark", "Sheep", "Shelf", "Shellfish", "Shirt", "Shorts",
"Shotgun", "Shower", "Shrimp", "Sink", "Skateboard", "Ski", "Skirt", "Skull", "Skunk", "Skyscraper",
"Slow cooker", "Snack", "Snail", "Snake", "Snowboard", "Snowman", "Snowmobile", "Snowplow",
"Soap dispenser", "Sock", "Sofa bed", "Sombrero", "Sparrow", "Spatula", "Spice rack", "Spider",
"Spoon", "Sports equipment", "Sports uniform", "Squash (Plant)", "Squid", "Squirrel", "Stairs",
"Stapler", "Starfish", "Stationary bicycle", "Stethoscope", "Stool", "Stop sign", "Strawberry",
"Street light", "Stretcher", "Studio couch", "Submarine", "Submarine sandwich", "Suit", "Suitcase",
"Sun hat", "Sunglasses", "Surfboard", "Sushi", "Swan", "Swim cap", "Swimming pool", "Swimwear",
"Sword", "Syringe", "Table", "Table tennis racket", "Tablet computer", "Tableware", "Taco", "Tank",
"Tap", "Tart", "Taxi", "Tea", "Teapot", "Teddy bear", "Telephone", "Television", "Tennis ball",
"Tennis racket", "Tent", "Tiara", "Tick", "Tie", "Tiger", "Tin can", "Tire", "Toaster", "Toilet",
"Toilet paper", "Tomato", "Tool", "Toothbrush", "Torch", "Tortoise", "Towel", "Tower", "Toy",
"Traffic light", "Traffic sign", "Train", "Training bench", "Treadmill", "Tree", "Tree house",
"Tripod", "Trombone", "Trousers", "Truck", "Trumpet", "Turkey", "Turtle", "Umbrella", "Unicycle",
"Van", "Vase", "Vegetable", "Vehicle", "Vehicle registration plate", "Violin", "Volleyball (Ball)",
"Waffle", "Waffle iron", "Wall clock", "Wardrobe", "Washing machine", "Waste container", "Watch",
"Watercraft", "Watermelon", "Weapon", "Whale", "Wheel", "Wheelchair", "Whisk", "Whiteboard", "Willow",
"Window", "Window blind", "Wine", "Wine glass", "Wine rack", "Winter melon", "Wok", "Woman",
"Wood-burning stove", "Woodpecker", "Worm", "Wrench", "Zebra", "Zucchini"
};

static cv::Scalar colors[] = {
cv::Scalar(244, 67, 54),
cv::Scalar(233, 30, 99),
cv::Scalar(156, 39, 176),
cv::Scalar(103, 58, 183),
cv::Scalar(63, 81, 181),
cv::Scalar(33, 150, 243),
cv::Scalar(3, 169, 244),
cv::Scalar(0, 188, 212),
cv::Scalar(0, 150, 136),
cv::Scalar(76, 175, 80),
cv::Scalar(139, 195, 74),
cv::Scalar(205, 220, 57),
cv::Scalar(255, 235, 59),
cv::Scalar(255, 193, 7),
cv::Scalar(255, 152, 0),
cv::Scalar(255, 87, 34),
cv::Scalar(121, 85, 72),
cv::Scalar(158, 158, 158),
cv::Scalar(96, 125, 139)
};

cv::Mat image = bgr.clone();

for (size_t i = 0; i < objects.size(); i++)
{
const Object& obj = objects[i];

cv::Scalar cc(color[0], color[1], color[2]);
const cv::Scalar& color = colors[i % 19];

fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);

cv::rectangle(image, obj.rect, cc, 2);
cv::rectangle(image, obj.rect, color);

char text[256];
sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
@@ -374,10 +595,10 @@ static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
x = image.cols - label_size.width;

cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
cc, -1);
cv::Scalar(255, 255, 255), -1);

cv::putText(image, text, cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 255));
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
}

cv::imshow("image", image);
@@ -404,7 +625,8 @@ int main(int argc, char** argv)
std::vector<Object> objects;
detect_yolov8(m, objects);

draw_objects(m, objects);
draw_objects_coco(m, objects);
// draw_objects_oiv(m, objects);

return 0;
}

+ 325
- 0
examples/yolov8_cls.cpp View File

@@ -0,0 +1,325 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// 1. install
// pip3 install -U ultralytics pnnx ncnn
// 2. export yolov8-cls torchscript
// yolo export model=yolov8n-cls.pt format=torchscript
// 3. convert torchscript with static shape
// pnnx yolov8n-cls.torchscript
// 4. now you get ncnn model files
// yolov8n_cls.ncnn.param
// yolov8n_cls.ncnn.bin

#include "net.h"

#if defined(USE_NCNN_SIMPLEOCV)
#include "simpleocv.h"
#else
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#endif
#include <float.h>
#include <stdio.h>
#include <vector>

struct Object
{
int label;
float prob;
};

static void get_topk(const ncnn::Mat& cls_scores, int topk, std::vector<Object>& objects)
{
// partial sort topk with index
int size = cls_scores.w;
std::vector<std::pair<float, int> > vec;
vec.resize(size);
for (int i = 0; i < size; i++)
{
vec[i] = std::make_pair(cls_scores[i], i);
}

std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(),
std::greater<std::pair<float, int> >());

objects.resize(topk);
for (int i = 0; i < topk; i++)
{
objects[i].label = vec[i].second;
objects[i].prob = vec[i].first;
}
}

static int detect_yolov8_cls(const cv::Mat& bgr, std::vector<Object>& objects)
{
ncnn::Net yolov8;

yolov8.opt.use_vulkan_compute = true;
// yolov8.opt.use_bf16_storage = true;

// https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets
yolov8.load_param("yolov8n_cls.ncnn.param");
yolov8.load_model("yolov8n_cls.ncnn.bin");
// yolov8.load_param("yolov8s_cls.ncnn.param");
// yolov8.load_model("yolov8s_cls.ncnn.bin");
// yolov8.load_param("yolov8m_cls.ncnn.param");
// yolov8.load_model("yolov8m_cls.ncnn.bin");

const int target_size = 224;
const int topk = 5;

int img_w = bgr.cols;
int img_h = bgr.rows;

// letterbox pad
int w = img_w;
int h = img_h;
float scale = 1.f;
if (w > h)
{
scale = (float)target_size / w;
w = target_size;
h = h * scale;
}
else
{
scale = (float)target_size / h;
h = target_size;
w = w * scale;
}

ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);

// letterbox pad to target_size rectangle
int wpad = target_size - w;
int hpad = target_size - h;
ncnn::Mat in_pad;
ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);

const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f};
in_pad.substract_mean_normalize(0, norm_vals);

ncnn::Extractor ex = yolov8.create_extractor();

ex.input("in0", in_pad);

ncnn::Mat out;
ex.extract("out0", out);

// return top-5
get_topk(out, topk, objects);

return 0;
}

static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
{
static const char* class_names[] = {
"tench", "goldfish", "great white shark", "tiger shark", "hammerhead", "electric ray", "stingray", "cock",
"hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "robin", "bulbul",
"jay", "magpie", "chickadee", "water ouzel", "kite", "bald eagle", "vulture", "great grey owl",
"European fire salamander", "common newt", "eft", "spotted salamander", "axolotl", "bullfrog", "tree frog",
"tailed frog", "loggerhead", "leatherback turtle", "mud turtle", "terrapin", "box turtle", "banded gecko",
"common iguana", "American chameleon", "whiptail", "agama", "frilled lizard", "alligator lizard",
"Gila monster", "green lizard", "African chameleon", "Komodo dragon", "African crocodile",
"American alligator", "triceratops", "thunder snake", "ringneck snake", "hognose snake", "green snake",
"king snake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "rock python",
"Indian cobra", "green mamba", "sea snake", "horned viper", "diamondback", "sidewinder", "trilobite",
"harvestman", "scorpion", "black and gold garden spider", "barn spider", "garden spider", "black widow",
"tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse",
"prairie chicken", "peacock", "quail", "partridge", "African grey", "macaw", "sulphur-crested cockatoo",
"lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "drake",
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala",
"wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug",
"sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "king crab",
"American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork",
"spoonbill", "flamingo", "little blue heron", "American egret", "bittern", "crane (bird)", "limpkin",
"European gallinule", "American coot", "bustard", "ruddy turnstone", "red-backed sandpiper", "redshank",
"dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale",
"dugong", "sea lion", "Chihuahua", "Japanese spaniel", "Maltese dog", "Pekinese", "Shih-Tzu",
"Blenheim spaniel", "papillon", "toy terrier", "Rhodesian ridgeback", "Afghan hound", "basset", "beagle",
"bloodhound", "bluetick", "black-and-tan coonhound", "Walker hound", "English foxhound", "redbone",
"borzoi", "Irish wolfhound", "Italian greyhound", "whippet", "Ibizan hound", "Norwegian elkhound",
"otterhound", "Saluki", "Scottish deerhound", "Weimaraner", "Staffordshire bullterrier",
"American Staffordshire terrier", "Bedlington terrier", "Border terrier", "Kerry blue terrier",
"Irish terrier", "Norfolk terrier", "Norwich terrier", "Yorkshire terrier", "wire-haired fox terrier",
"Lakeland terrier", "Sealyham terrier", "Airedale", "cairn", "Australian terrier", "Dandie Dinmont",
"Boston bull", "miniature schnauzer", "giant schnauzer", "standard schnauzer", "Scotch terrier",
"Tibetan terrier", "silky terrier", "soft-coated wheaten terrier", "West Highland white terrier",
"Lhasa", "flat-coated retriever", "curly-coated retriever", "golden retriever", "Labrador retriever",
"Chesapeake Bay retriever", "German short-haired pointer", "vizsla", "English setter", "Irish setter",
"Gordon setter", "Brittany spaniel", "clumber", "English springer", "Welsh springer spaniel",
"cocker spaniel", "Sussex spaniel", "Irish water spaniel", "kuvasz", "schipperke", "groenendael",
"malinois", "briard", "kelpie", "komondor", "Old English sheepdog", "Shetland sheepdog", "collie",
"Border collie", "Bouvier des Flandres", "Rottweiler", "German shepherd", "Doberman",
"miniature pinscher", "Greater Swiss Mountain dog", "Bernese mountain dog", "Appenzeller", "EntleBucher",
"boxer", "bull mastiff", "Tibetan mastiff", "French bulldog", "Great Dane", "Saint Bernard",
"Eskimo dog", "malamute", "Siberian husky", "dalmatian", "affenpinscher", "basenji", "pug", "Leonberg",
"Newfoundland", "Great Pyrenees", "Samoyed", "Pomeranian", "chow", "keeshond", "Brabancon griffon",
"Pembroke", "Cardigan", "toy poodle", "miniature poodle", "standard poodle", "Mexican hairless",
"timber wolf", "white wolf", "red wolf", "coyote", "dingo", "dhole", "African hunting dog", "hyena",
"red fox", "kit fox", "Arctic fox", "grey fox", "tabby", "tiger cat", "Persian cat", "Siamese cat",
"Egyptian cat", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah",
"brown bear", "American black bear", "ice bear", "sloth bear", "mongoose", "meerkat", "tiger beetle",
"ladybug", "ground beetle", "long-horned beetle", "leaf beetle", "dung beetle", "rhinoceros beetle",
"weevil", "fly", "bee", "ant", "grasshopper", "cricket", "walking stick", "cockroach", "mantis",
"cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "admiral", "ringlet", "monarch",
"cabbage butterfly", "sulphur butterfly", "lycaenid", "starfish", "sea urchin", "sea cucumber",
"wood rabbit", "hare", "Angora", "hamster", "porcupine", "fox squirrel", "marmot", "beaver",
"guinea pig", "sorrel", "zebra", "hog", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo",
"bison", "ram", "bighorn", "ibex", "hartebeest", "impala", "gazelle", "Arabian camel", "llama",
"weasel", "mink", "polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo",
"three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas",
"baboon", "macaque", "langur", "colobus", "proboscis monkey", "marmoset", "capuchin", "howler monkey",
"titi", "spider monkey", "squirrel monkey", "Madagascar cat", "indri", "Indian elephant",
"African elephant", "lesser panda", "giant panda", "barracouta", "eel", "coho", "rock beauty",
"anemone fish", "sturgeon", "gar", "lionfish", "puffer", "abacus", "abaya", "academic gown",
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance",
"amphibian", "analog clock", "apiary", "apron", "ashcan", "assault rifle", "backpack", "bakery",
"balance beam", "balloon", "ballpoint", "Band Aid", "banjo", "bannister", "barbell", "barber chair",
"barbershop", "barn", "barometer", "barrel", "barrow", "baseball", "basketball", "bassinet", "bassoon",
"bathing cap", "bath towel", "bathtub", "beach wagon", "beacon", "beaker", "bearskin", "beer bottle",
"beer glass", "bell cote", "bib", "bicycle-built-for-two", "bikini", "binder", "binoculars",
"birdhouse", "boathouse", "bobsled", "bolo tie", "bonnet", "bookcase", "bookshop", "bottlecap", "bow",
"bow tie", "brass", "brassiere", "breakwater", "breastplate", "broom", "bucket", "buckle",
"bulletproof vest", "bullet train", "butcher shop", "cab", "caldron", "candle", "cannon", "canoe",
"can opener", "cardigan", "car mirror", "carousel", "carpenter's kit", "carton", "car wheel",
"cash machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello",
"cellular telephone", "chain", "chainlink fence", "chain mail", "chain saw", "chest", "chiffonier",
"chime", "china cabinet", "Christmas stocking", "church", "cinema", "cleaver", "cliff dwelling",
"cloak", "clog", "cocktail shaker", "coffee mug", "coffeepot", "coil", "combination lock",
"computer keyboard", "confectionery", "container ship", "convertible", "corkscrew", "cornet",
"cowboy boot", "cowboy hat", "cradle", "crane (machine)", "crash helmet", "crate", "crib",
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "dial telephone",
"diaper", "digital clock", "digital watch", "dining table", "dishrag", "dishwasher", "disk brake",
"dock", "dogsled", "dome", "doormat", "drilling platform", "drum", "drumstick", "dumbbell",
"Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center",
"envelope", "espresso maker", "face powder", "feather boa", "file", "fireboat", "fire engine",
"fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain",
"fountain pen", "four-poster", "freight car", "French horn", "frying pan", "fur coat", "garbage truck",
"gasmask", "gas pump", "goblet", "go-kart", "golf ball", "golfcart", "gondola", "gong", "gown",
"grand piano", "greenhouse", "grille", "grocery store", "guillotine", "hair slide", "hair spray",
"half track", "hammer", "hamper", "hand blower", "hand-held computer", "handkerchief", "hard disc",
"harmonica", "harp", "harvester", "hatchet", "holster", "home theater", "honeycomb", "hook",
"hoopskirt", "horizontal bar", "horse cart", "hourglass", "iPod", "iron", "jack-o'-lantern", "jean",
"jeep", "jersey", "jigsaw puzzle", "jinrikisha", "joystick", "kimono", "knee pad", "knot", "lab coat",
"ladle", "lampshade", "laptop", "lawn mower", "lens cap", "letter opener", "library", "lifeboat",
"lighter", "limousine", "liner", "lipstick", "Loafer", "lotion", "loudspeaker", "loupe", "lumbermill",
"magnetic compass", "mailbag", "mailbox", "maillot (tights)", "maillot (tank suit)", "manhole cover",
"maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine chest",
"megalith", "microphone", "microwave", "military uniform", "milk can", "minibus", "miniskirt",
"minivan", "missile", "mitten", "mixing bowl", "mobile home", "Model T", "modem", "monastery",
"monitor", "moped", "mortar", "mortarboard", "mosque", "mosquito net", "motor scooter", "mountain bike",
"mountain tent", "mouse", "mousetrap", "moving van", "muzzle", "nail", "neck brace", "necklace",
"nipple", "notebook", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "organ", "oscilloscope",
"overskirt", "oxcart", "oxygen mask", "packet", "paddle", "paddlewheel", "padlock", "paintbrush",
"pajama", "palace", "panpipe", "paper towel", "parachute", "parallel bars", "park bench",
"parking meter", "passenger car", "patio", "pay-phone", "pedestal", "pencil box", "pencil sharpener",
"perfume", "Petri dish", "photocopier", "pick", "pickelhaube", "picket fence", "pickup", "pier",
"piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate", "pitcher", "plane",
"planetarium", "plastic bag", "plate rack", "plow", "plunger", "Polaroid camera", "pole",
"police van", "poncho", "pool table", "pop bottle", "pot", "potter's wheel", "power drill",
"prayer rug", "printer", "prison", "projectile", "projector", "puck", "punching bag", "purse",
"quill", "quilt", "racer", "racket", "radiator", "radio", "radio telescope", "rain barrel",
"recreational vehicle", "reel", "reflex camera", "refrigerator", "remote control", "restaurant",
"revolver", "rifle", "rocking chair", "rotisserie", "rubber eraser", "rugby ball", "rule",
"running shoe", "safe", "safety pin", "saltshaker", "sandal", "sarong", "sax", "scabbard", "scale",
"school bus", "schooner", "scoreboard", "screen", "screw", "screwdriver", "seat belt", "sewing machine",
"shield", "shoe shop", "shoji", "shopping basket", "shopping cart", "shovel", "shower cap",
"shower curtain", "ski", "ski mask", "sleeping bag", "slide rule", "sliding door", "slot", "snorkel",
"snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar dish", "sombrero",
"soup bowl", "space bar", "space heater", "space shuttle", "spatula", "speedboat", "spider web",
"spindle", "sports car", "spotlight", "stage", "steam locomotive", "steel arch bridge", "steel drum",
"stethoscope", "stole", "stone wall", "stopwatch", "stove", "strainer", "streetcar", "stretcher",
"studio couch", "stupa", "submarine", "suit", "sundial", "sunglass", "sunglasses", "sunscreen",
"suspension bridge", "swab", "sweatshirt", "swimming trunks", "swing", "switch", "syringe",
"table lamp", "tank", "tape player", "teapot", "teddy", "television", "tennis ball", "thatch",
"theater curtain", "thimble", "thresher", "throne", "tile roof", "toaster", "tobacco shop",
"toilet seat", "torch", "totem pole", "tow truck", "toyshop", "tractor", "trailer truck", "tray",
"trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "tub",
"turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright", "vacuum", "vase", "vault",
"velvet", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock",
"wallet", "wardrobe", "warplane", "washbasin", "washer", "water bottle", "water jug", "water tower",
"whiskey jug", "whistle", "wig", "window screen", "window shade", "Windsor tie", "wine bottle", "wing",
"wok", "wooden spoon", "wool", "worm fence", "wreck", "yawl", "yurt", "web site", "comic book",
"crossword puzzle", "street sign", "traffic light", "book jacket", "menu", "plate", "guacamole",
"consomme", "hot pot", "trifle", "ice cream", "ice lolly", "French loaf", "bagel", "pretzel",
"cheeseburger", "hotdog", "mashed potato", "head cabbage", "broccoli", "cauliflower", "zucchini",
"spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper",
"cardoon", "mushroom", "Granny Smith", "strawberry", "orange", "lemon", "fig", "pineapple", "banana",
"jackfruit", "custard apple", "pomegranate", "hay", "carbonara", "chocolate sauce", "dough",
"meat loaf", "pizza", "potpie", "burrito", "red wine", "espresso", "cup", "eggnog", "alp", "bubble",
"cliff", "coral reef", "geyser", "lakeside", "promontory", "sandbar", "seashore", "valley", "volcano",
"ballplayer", "groom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn",
"hip", "buckeye", "coral fungus", "agaric", "gyromitra", "stinkhorn", "earthstar", "hen-of-the-woods",
"bolete", "ear", "toilet tissue"
};

cv::Mat image = bgr.clone();

int y_offset = 0;
for (size_t i = 0; i < objects.size(); i++)
{
const Object& obj = objects[i];

fprintf(stderr, "%d = %.5f\n", obj.label, obj.prob);

char text[256];
sprintf(text, "%4.1f%% %s", obj.prob * 100, class_names[obj.label]);

int baseLine = 0;
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

int x = 0;
int y = y_offset;

cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
cv::Scalar(255, 255, 255), -1);

cv::putText(image, text, cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));

y_offset += label_size.height;
}

cv::imshow("image", image);
cv::waitKey(0);
}

int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
return -1;
}

const char* imagepath = argv[1];

cv::Mat m = cv::imread(imagepath, 1);
if (m.empty())
{
fprintf(stderr, "cv::imread %s failed\n", imagepath);
return -1;
}

std::vector<Object> objects;
detect_yolov8_cls(m, objects);

draw_objects(m, objects);

return 0;
}

+ 522
- 0
examples/yolov8_obb.cpp View File

@@ -0,0 +1,522 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// 1. install
// pip3 install -U ultralytics pnnx ncnn
// 2. export yolov8-obb torchscript
// yolo export model=yolov8n-obb.pt format=torchscript
// 3. convert torchscript with static shape
// pnnx yolov8n-obb.torchscript
// 4. modify yolov8n_obb_pnnx.py for dynamic shape inference
// A. modify reshape to support dynamic image sizes
// B. permute tensor before concat and adjust concat axis
// C. drop post-process part
// before:
// v_137 = v_136.view(1, 1, 16384)
// v_143 = v_142.view(1, 1, 4096)
// v_149 = v_148.view(1, 1, 1024)
// v_150 = torch.cat((v_137, v_143, v_149), dim=2)
// ...
// v_186 = v_163.view(1, 79, 16384)
// v_187 = v_174.view(1, 79, 4096)
// v_188 = v_185.view(1, 79, 1024)
// v_189 = torch.cat((v_186, v_187, v_188), dim=2)
// ...
// after:
// v_137 = v_136.view(1, 1, -1).transpose(1, 2)
// v_143 = v_142.view(1, 1, -1).transpose(1, 2)
// v_149 = v_148.view(1, 1, -1).transpose(1, 2)
// v_150 = torch.cat((v_137, v_143, v_149), dim=1)
// ...
// v_186 = v_163.view(1, 79, -1).transpose(1, 2)
// v_187 = v_174.view(1, 79, -1).transpose(1, 2)
// v_188 = v_185.view(1, 79, -1).transpose(1, 2)
// v_189 = torch.cat((v_186, v_187, v_188), dim=1)
// return v_189, v_150
// 5. re-export yolov8-obb torchscript
// python3 -c 'import yolov8n_obb_pnnx; yolov8n_obb_pnnx.export_torchscript()'
// 6. convert new torchscript with dynamic shape
// pnnx yolov8n_obb_pnnx.py.pt inputshape=[1,3,1024,1024] inputshape2=[1,3,512,512]
// 7. now you get ncnn model files
// mv yolov8n_obb_pnnx.py.ncnn.param yolov8n_obb.ncnn.param
// mv yolov8n_obb_pnnx.py.ncnn.bin yolov8n_obb.ncnn.bin

// the out blob would be a 2-dim tensor with w=79 h=21504
//
// | bbox-reg 16 x 4 |score(15)|
// +-----+-----+-----+-----+---------+
// | dx0 | dy0 | dx1 | dy1 | 0.1 ... |
// all /| | | | | ... |
// boxes | .. | .. | .. | .. | 0.0 ... |
// (21504)| | | | | . ... |
// \| | | | | . ... |
// +-----+-----+-----+-----+---------+
//

// the out blob would be a 2-dim tensor with w=1 h=21504
//
// | degree(1)|
// +----------+
// | 0.1 |
// all /| |
// boxes | 0.0 |
// (21504)| . |
// \| . |
// +----------+
//

#include "layer.h"
#include "net.h"

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>

#include <float.h>
#include <math.h>
#include <stdio.h>
#include <vector>

struct Object
{
cv::RotatedRect rrect;
int label;
float prob;
};

static inline float intersection_area(const Object& a, const Object& b)
{
std::vector<cv::Point2f> intersection;
cv::rotatedRectangleIntersection(a.rrect, b.rrect, intersection);
if (intersection.empty())
return 0.f;

return cv::contourArea(intersection);
}

static void qsort_descent_inplace(std::vector<Object>& objects, int left, int right)
{
int i = left;
int j = right;
float p = objects[(left + right) / 2].prob;

while (i <= j)
{
while (objects[i].prob > p)
i++;

while (objects[j].prob < p)
j--;

if (i <= j)
{
// swap
std::swap(objects[i], objects[j]);

i++;
j--;
}
}

// #pragma omp parallel sections
{
// #pragma omp section
{
if (left < j) qsort_descent_inplace(objects, left, j);
}
// #pragma omp section
{
if (i < right) qsort_descent_inplace(objects, i, right);
}
}
}

static void qsort_descent_inplace(std::vector<Object>& objects)
{
if (objects.empty())
return;

qsort_descent_inplace(objects, 0, objects.size() - 1);
}

static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
picked.clear();

const int n = objects.size();

std::vector<float> areas(n);
for (int i = 0; i < n; i++)
{
areas[i] = objects[i].rrect.size.area();
}

for (int i = 0; i < n; i++)
{
const Object& a = objects[i];

int keep = 1;
for (int j = 0; j < (int)picked.size(); j++)
{
const Object& b = objects[picked[j]];

if (!agnostic && a.label != b.label)
continue;

// intersection over union
float inter_area = intersection_area(a, b);
float union_area = areas[i] + areas[picked[j]] - inter_area;
// float IoU = inter_area / union_area;
if (inter_area / union_area > nms_threshold)
keep = 0;
}

if (keep)
picked.push_back(i);
}
}

static inline float sigmoid(float x)
{
return 1.0f / (1.0f + expf(-x));
}

static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_angle, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const int w = in_pad.w;
const int h = in_pad.h;

const int num_grid_x = w / stride;
const int num_grid_y = h / stride;

const int reg_max_1 = 16;
const int num_class = pred.w - reg_max_1 * 4; // number of classes. 15 for DOTAv1

for (int y = 0; y < num_grid_y; y++)
{
for (int x = 0; x < num_grid_x; x++)
{
const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1);

// find label with max score
int label = -1;
float score = -FLT_MAX;
{
const ncnn::Mat pred_score = pred_grid.range(reg_max_1 * 4, num_class);

for (int k = 0; k < num_class; k++)
{
float s = pred_score[k];
if (s > score)
{
label = k;
score = s;
}
}

score = sigmoid(score);
}

if (score >= prob_threshold)
{
ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4).clone();

{
ncnn::Layer* softmax = ncnn::create_layer("Softmax");

ncnn::ParamDict pd;
pd.set(0, 1); // axis
pd.set(1, 1);
softmax->load_param(pd);

ncnn::Option opt;
opt.num_threads = 1;
opt.use_packing_layout = false;

softmax->create_pipeline(opt);

softmax->forward_inplace(pred_bbox, opt);

softmax->destroy_pipeline(opt);

delete softmax;
}

float pred_ltrb[4];
for (int k = 0; k < 4; k++)
{
float dis = 0.f;
const float* dis_after_sm = pred_bbox.row(k);
for (int l = 0; l < reg_max_1; l++)
{
dis += l * dis_after_sm[l];
}

pred_ltrb[k] = dis * stride;
}

float pb_cx = (x + 0.5f) * stride;
float pb_cy = (y + 0.5f) * stride;

const float angle = sigmoid(pred_angle.row(y * num_grid_x + x)[0]) - 0.25f;

const float angle_rad = angle * 3.14159265358979323846f;
const float angle_degree = angle * 180.f;

float cos = cosf(angle_rad);
float sin = sinf(angle_rad);

float xx = (pred_ltrb[2] - pred_ltrb[0]) * 0.5f;
float yy = (pred_ltrb[3] - pred_ltrb[1]) * 0.5f;
float xr = xx * cos - yy * sin;
float yr = xx * sin + yy * cos;
const float cx = pb_cx + xr;
const float cy = pb_cy + yr;
const float ww = pred_ltrb[2] + pred_ltrb[0];
const float hh = pred_ltrb[3] + pred_ltrb[1];

Object obj;
obj.rrect = cv::RotatedRect(cv::Point2f(cx, cy), cv::Size_<float>(ww, hh), angle_degree);
obj.label = label;
obj.prob = score;

objects.push_back(obj);
}
}
}
}

static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_angle, const std::vector<int>& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const int w = in_pad.w;
const int h = in_pad.h;

int pred_row_offset = 0;
for (size_t i = 0; i < strides.size(); i++)
{
const int stride = strides[i];

const int num_grid_x = w / stride;
const int num_grid_y = h / stride;
const int num_grid = num_grid_x * num_grid_y;

generate_proposals(pred.row_range(pred_row_offset, num_grid), pred_angle.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects);

pred_row_offset += num_grid;
}
}

static int detect_yolov8_obb(const cv::Mat& bgr, std::vector<Object>& objects)
{
ncnn::Net yolov8;

yolov8.opt.use_vulkan_compute = true;
// yolov8.opt.use_bf16_storage = true;

// https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets
yolov8.load_param("yolov8n_obb.ncnn.param");
yolov8.load_model("yolov8n_obb.ncnn.bin");
// yolov8.load_param("yolov8s_obb.ncnn.param");
// yolov8.load_model("yolov8s_obb.ncnn.bin");
// yolov8.load_param("yolov8m_obb.ncnn.param");
// yolov8.load_model("yolov8m_obb.ncnn.bin");

const int target_size = 1024;
const float prob_threshold = 0.25f;
const float nms_threshold = 0.45f;

int img_w = bgr.cols;
int img_h = bgr.rows;

// ultralytics/cfg/models/v8/yolov8.yaml
std::vector<int> strides(3);
strides[0] = 8;
strides[1] = 16;
strides[2] = 32;
const int max_stride = 32;

// letterbox pad to multiple of max_stride
int w = img_w;
int h = img_h;
float scale = 1.f;
if (w > h)
{
scale = (float)target_size / w;
w = target_size;
h = h * scale;
}
else
{
scale = (float)target_size / h;
h = target_size;
w = w * scale;
}

ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);

// letterbox pad to target_size rectangle
int wpad = (w + max_stride - 1) / max_stride * max_stride - w;
int hpad = (h + max_stride - 1) / max_stride * max_stride - h;
ncnn::Mat in_pad;
ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);

const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f};
in_pad.substract_mean_normalize(0, norm_vals);

ncnn::Extractor ex = yolov8.create_extractor();

ex.input("in0", in_pad);

ncnn::Mat out;
ex.extract("out0", out);

ncnn::Mat out_angle;
ex.extract("out1", out_angle);

std::vector<Object> proposals;
generate_proposals(out, out_angle, strides, in_pad, prob_threshold, proposals);

// sort all proposals by score from highest to lowest
qsort_descent_inplace(proposals);

// apply nms with nms_threshold
std::vector<int> picked;
nms_sorted_bboxes(proposals, picked, nms_threshold);

int count = picked.size();
if (count == 0)
return 0;

objects.resize(count);
for (int i = 0; i < count; i++)
{
Object obj = proposals[picked[i]];

// adjust offset to original unpadded
obj.rrect.center.x = (obj.rrect.center.x - (wpad / 2)) / scale;
obj.rrect.center.y = (obj.rrect.center.y - (hpad / 2)) / scale;
obj.rrect.size.width = (obj.rrect.size.width) / scale;
obj.rrect.size.height = (obj.rrect.size.height) / scale;

objects[i] = obj;
}

return 0;
}

static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
{
static const char* class_names[] = {
"plane", "ship", "storage tank", "baseball diamond", "tennis court",
"basketball court", "ground track field", "harbor", "bridge", "large vehicle",
"small vehicle", "helicopter", "roundabout", "soccer ball field", "swimming pool"
};

static const cv::Scalar colors[] = {
cv::Scalar(156, 39, 176),
cv::Scalar(103, 58, 183),
cv::Scalar(63, 81, 181),
cv::Scalar(33, 150, 243),
cv::Scalar(3, 169, 244),
cv::Scalar(0, 188, 212),
cv::Scalar(0, 150, 136),
cv::Scalar(76, 175, 80),
cv::Scalar(139, 195, 74),
cv::Scalar(205, 220, 57),
cv::Scalar(255, 235, 59),
cv::Scalar(255, 193, 7),
cv::Scalar(255, 152, 0),
cv::Scalar(255, 87, 34),
cv::Scalar(121, 85, 72),
cv::Scalar(158, 158, 158),
cv::Scalar(96, 125, 139)
};

cv::Mat image = bgr.clone();

for (size_t i = 0; i < objects.size(); i++)
{
const Object& obj = objects[i];

const cv::Scalar& color = colors[obj.label];

fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f @ %.2f\n", obj.label, obj.prob,
obj.rrect.center.x, obj.rrect.center.y, obj.rrect.size.width, obj.rrect.size.height, obj.rrect.angle);

cv::Point2f corners[4];
obj.rrect.points(corners);
cv::line(image, corners[0], corners[1], color);
cv::line(image, corners[1], corners[2], color);
cv::line(image, corners[2], corners[3], color);
cv::line(image, corners[3], corners[0], color);
}

for (size_t i = 0; i < objects.size(); i++)
{
const Object& obj = objects[i];

const cv::Scalar& color = colors[obj.label];

char text[256];
sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);

int baseLine = 0;
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

int x = obj.rrect.center.x - label_size.width / 2;
int y = obj.rrect.center.y - label_size.height / 2 - baseLine;
if (y < 0)
y = 0;
if (y + label_size.height > image.rows)
y = image.rows - label_size.height;
if (x < 0)
x = 0;
if (x + label_size.width > image.cols)
x = image.cols - label_size.width;

cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
cv::Scalar(255, 255, 255), -1);

cv::putText(image, text, cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
}

cv::imshow("image", image);
cv::waitKey(0);
}

int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
return -1;
}

const char* imagepath = argv[1];

cv::Mat m = cv::imread(imagepath, 1);
if (m.empty())
{
fprintf(stderr, "cv::imread %s failed\n", imagepath);
return -1;
}

std::vector<Object> objects;
detect_yolov8_obb(m, objects);

draw_objects(m, objects);

return 0;
}

+ 561
- 0
examples/yolov8_pose.cpp View File

@@ -0,0 +1,561 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// 1. install
// pip3 install -U ultralytics pnnx ncnn
// 2. export yolov8-pose torchscript
// yolo export model=yolov8n-pose.pt format=torchscript
// 3. convert torchscript with static shape
// pnnx yolov8n-pose.torchscript
// 4. modify yolov8n_pose_pnnx.py for dynamic shape inference
// A. modify reshape to support dynamic image sizes
// B. permute tensor before concat and adjust concat axis
// C. drop post-process part
// before:
// v_137 = v_136.view(1, 51, 6400)
// v_143 = v_142.view(1, 51, 1600)
// v_149 = v_148.view(1, 51, 400)
// v_150 = torch.cat((v_137, v_143, v_149), dim=-1)
// ...
// v_184 = v_161.view(1, 65, 6400)
// v_185 = v_172.view(1, 65, 1600)
// v_186 = v_183.view(1, 65, 400)
// v_187 = torch.cat((v_184, v_185, v_186), dim=2)
// ...
// after:
// v_137 = v_136.view(1, 51, -1).transpose(1, 2)
// v_143 = v_142.view(1, 51, -1).transpose(1, 2)
// v_149 = v_148.view(1, 51, -1).transpose(1, 2)
// v_150 = torch.cat((v_137, v_143, v_149), dim=1)
// ...
// v_184 = v_161.view(1, 65, -1).transpose(1, 2)
// v_185 = v_172.view(1, 65, -1).transpose(1, 2)
// v_186 = v_183.view(1, 65, -1).transpose(1, 2)
// v_187 = torch.cat((v_184, v_185, v_186), dim=1)
// return v_187, v_150
// 5. re-export yolov8-pose torchscript
// python3 -c 'import yolov8n_pose_pnnx; yolov8n_pose_pnnx.export_torchscript()'
// 6. convert new torchscript with dynamic shape
// pnnx yolov8n_pose_pnnx.py.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320]
// 7. now you get ncnn model files
// mv yolov8n_pose_pnnx.py.ncnn.param yolov8n_pose.ncnn.param
// mv yolov8n_pose_pnnx.py.ncnn.bin yolov8n_pose.ncnn.bin

// the out blob would be a 2-dim tensor with w=65 h=8400
//
// | bbox-reg 16 x 4 |score(1)|
// +-----+-----+-----+-----+--------+
// | dx0 | dy0 | dx1 | dy1 | 0.1 |
// all /| | | | | |
// boxes | .. | .. | .. | .. | 0.0 |
// (8400)| | | | | . |
// \| | | | | . |
// +-----+-----+-----+-----+--------+
//

//
// | pose (51) |
// +-----------+
// |0.1........|
// all /| |
// boxes |0.0........|
// (8400)| . |
// \| . |
// +-----------+
//

#include "layer.h"
#include "net.h"

#if defined(USE_NCNN_SIMPLEOCV)
#include "simpleocv.h"
#else
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#endif
#include <float.h>
#include <stdio.h>
#include <vector>

struct KeyPoint
{
cv::Point2f p;
float prob;
};

struct Object
{
cv::Rect_<float> rect;
int label;
float prob;
std::vector<KeyPoint> keypoints;
};

static inline float intersection_area(const Object& a, const Object& b)
{
cv::Rect_<float> inter = a.rect & b.rect;
return inter.area();
}

static void qsort_descent_inplace(std::vector<Object>& objects, int left, int right)
{
int i = left;
int j = right;
float p = objects[(left + right) / 2].prob;

while (i <= j)
{
while (objects[i].prob > p)
i++;

while (objects[j].prob < p)
j--;

if (i <= j)
{
// swap
std::swap(objects[i], objects[j]);

i++;
j--;
}
}

// #pragma omp parallel sections
{
// #pragma omp section
{
if (left < j) qsort_descent_inplace(objects, left, j);
}
// #pragma omp section
{
if (i < right) qsort_descent_inplace(objects, i, right);
}
}
}

static void qsort_descent_inplace(std::vector<Object>& objects)
{
if (objects.empty())
return;

qsort_descent_inplace(objects, 0, objects.size() - 1);
}

static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
picked.clear();

const int n = objects.size();

std::vector<float> areas(n);
for (int i = 0; i < n; i++)
{
areas[i] = objects[i].rect.area();
}

for (int i = 0; i < n; i++)
{
const Object& a = objects[i];

int keep = 1;
for (int j = 0; j < (int)picked.size(); j++)
{
const Object& b = objects[picked[j]];

if (!agnostic && a.label != b.label)
continue;

// intersection over union
float inter_area = intersection_area(a, b);
float union_area = areas[i] + areas[picked[j]] - inter_area;
// float IoU = inter_area / union_area
if (inter_area / union_area > nms_threshold)
keep = 0;
}

if (keep)
picked.push_back(i);
}
}

static inline float sigmoid(float x)
{
return 1.0f / (1.0f + expf(-x));
}

static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_points, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const int w = in_pad.w;
const int h = in_pad.h;

const int num_grid_x = w / stride;
const int num_grid_y = h / stride;

const int reg_max_1 = 16;
const int num_points = pred_points.w / 3;

for (int y = 0; y < num_grid_y; y++)
{
for (int x = 0; x < num_grid_x; x++)
{
const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1);
const ncnn::Mat pred_points_grid = pred_points.row_range(y * num_grid_x + x, 1).reshape(3, num_points);

// find label with max score
int label = 0;
float score = sigmoid(pred_grid[reg_max_1 * 4]);

if (score >= prob_threshold)
{
ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4).clone();

{
ncnn::Layer* softmax = ncnn::create_layer("Softmax");

ncnn::ParamDict pd;
pd.set(0, 1); // axis
pd.set(1, 1);
softmax->load_param(pd);

ncnn::Option opt;
opt.num_threads = 1;
opt.use_packing_layout = false;

softmax->create_pipeline(opt);

softmax->forward_inplace(pred_bbox, opt);

softmax->destroy_pipeline(opt);

delete softmax;
}

float pred_ltrb[4];
for (int k = 0; k < 4; k++)
{
float dis = 0.f;
const float* dis_after_sm = pred_bbox.row(k);
for (int l = 0; l < reg_max_1; l++)
{
dis += l * dis_after_sm[l];
}

pred_ltrb[k] = dis * stride;
}

float pb_cx = (x + 0.5f) * stride;
float pb_cy = (y + 0.5f) * stride;

float x0 = pb_cx - pred_ltrb[0];
float y0 = pb_cy - pred_ltrb[1];
float x1 = pb_cx + pred_ltrb[2];
float y1 = pb_cy + pred_ltrb[3];

std::vector<KeyPoint> keypoints;
for (int k = 0; k < num_points; k++)
{
KeyPoint keypoint;
keypoint.p.x = (x + pred_points_grid.row(k)[0] * 2) * stride;
keypoint.p.y = (y + pred_points_grid.row(k)[1] * 2) * stride;
keypoint.prob = sigmoid(pred_points_grid.row(k)[2]);
keypoints.push_back(keypoint);
}

Object obj;
obj.rect.x = x0;
obj.rect.y = y0;
obj.rect.width = x1 - x0;
obj.rect.height = y1 - y0;
obj.label = label;
obj.prob = score;
obj.keypoints = keypoints;

objects.push_back(obj);
}
}
}
}

static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_points, const std::vector<int>& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const int w = in_pad.w;
const int h = in_pad.h;

int pred_row_offset = 0;
for (size_t i = 0; i < strides.size(); i++)
{
const int stride = strides[i];

const int num_grid_x = w / stride;
const int num_grid_y = h / stride;
const int num_grid = num_grid_x * num_grid_y;

generate_proposals(pred.row_range(pred_row_offset, num_grid), pred_points.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects);

pred_row_offset += num_grid;
}
}

static int detect_yolov8_pose(const cv::Mat& bgr, std::vector<Object>& objects)
{
ncnn::Net yolov8;

yolov8.opt.use_vulkan_compute = true;
// yolov8.opt.use_bf16_storage = true;

// https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets
yolov8.load_param("yolov8n_pose.ncnn.param");
yolov8.load_model("yolov8n_pose.ncnn.bin");
// yolov8.load_param("yolov8s_pose.ncnn.param");
// yolov8.load_model("yolov8s_pose.ncnn.bin");
// yolov8.load_param("yolov8m_pose.ncnn.param");
// yolov8.load_model("yolov8m_pose.ncnn.bin");

const int target_size = 640;
const float prob_threshold = 0.25f;
const float nms_threshold = 0.45f;
const float mask_threshold = 0.5f;

int img_w = bgr.cols;
int img_h = bgr.rows;

// ultralytics/cfg/models/v8/yolov8.yaml
std::vector<int> strides(3);
strides[0] = 8;
strides[1] = 16;
strides[2] = 32;
const int max_stride = 32;

// letterbox pad to multiple of max_stride
int w = img_w;
int h = img_h;
float scale = 1.f;
if (w > h)
{
scale = (float)target_size / w;
w = target_size;
h = h * scale;
}
else
{
scale = (float)target_size / h;
h = target_size;
w = w * scale;
}

ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);

// letterbox pad to target_size rectangle
int wpad = (w + max_stride - 1) / max_stride * max_stride - w;
int hpad = (h + max_stride - 1) / max_stride * max_stride - h;
ncnn::Mat in_pad;
ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);

const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f};
in_pad.substract_mean_normalize(0, norm_vals);

ncnn::Extractor ex = yolov8.create_extractor();

ex.input("in0", in_pad);

ncnn::Mat out;
ex.extract("out0", out);

ncnn::Mat out_points;
ex.extract("out1", out_points);

std::vector<Object> proposals;
generate_proposals(out, out_points, strides, in_pad, prob_threshold, proposals);

// sort all proposals by score from highest to lowest
qsort_descent_inplace(proposals);

// apply nms with nms_threshold
std::vector<int> picked;
nms_sorted_bboxes(proposals, picked, nms_threshold);

int count = picked.size();
if (count == 0)
return 0;

const int num_points = out_points.w / 3;

objects.resize(count);
for (int i = 0; i < count; i++)
{
objects[i] = proposals[picked[i]];

// adjust offset to original unpadded
float x0 = (objects[i].rect.x - (wpad / 2)) / scale;
float y0 = (objects[i].rect.y - (hpad / 2)) / scale;
float x1 = (objects[i].rect.x + objects[i].rect.width - (wpad / 2)) / scale;
float y1 = (objects[i].rect.y + objects[i].rect.height - (hpad / 2)) / scale;

for (int j = 0; j < num_points; j++)
{
objects[i].keypoints[j].p.x = (objects[i].keypoints[j].p.x - (wpad / 2)) / scale;
objects[i].keypoints[j].p.y = (objects[i].keypoints[j].p.y - (hpad / 2)) / scale;
}

// clip
x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);

objects[i].rect.x = x0;
objects[i].rect.y = y0;
objects[i].rect.width = x1 - x0;
objects[i].rect.height = y1 - y0;
}

return 0;
}

static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
{
static const char* class_names[] = {"person"};

static const cv::Scalar colors[] = {
cv::Scalar(244, 67, 54),
cv::Scalar(233, 30, 99),
cv::Scalar(156, 39, 176),
cv::Scalar(103, 58, 183),
cv::Scalar(63, 81, 181),
cv::Scalar(33, 150, 243),
cv::Scalar(3, 169, 244),
cv::Scalar(0, 188, 212),
cv::Scalar(0, 150, 136),
cv::Scalar(76, 175, 80),
cv::Scalar(139, 195, 74),
cv::Scalar(205, 220, 57),
cv::Scalar(255, 235, 59),
cv::Scalar(255, 193, 7),
cv::Scalar(255, 152, 0),
cv::Scalar(255, 87, 34),
cv::Scalar(121, 85, 72),
cv::Scalar(158, 158, 158),
cv::Scalar(96, 125, 139)
};

cv::Mat image = bgr.clone();

for (size_t i = 0; i < objects.size(); i++)
{
const Object& obj = objects[i];

const cv::Scalar& color = colors[i % 19];

fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);

// draw bone
static const int joint_pairs[16][2] = {
{0, 1}, {1, 3}, {0, 2}, {2, 4}, {5, 6}, {5, 7}, {7, 9}, {6, 8}, {8, 10}, {5, 11}, {6, 12}, {11, 12}, {11, 13}, {12, 14}, {13, 15}, {14, 16}
};
static const cv::Scalar bone_colors[] = {
cv::Scalar(0, 255, 0),
cv::Scalar(0, 255, 0),
cv::Scalar(0, 255, 0),
cv::Scalar(0, 255, 0),
cv::Scalar(255, 128, 0),
cv::Scalar(255, 128, 0),
cv::Scalar(255, 128, 0),
cv::Scalar(255, 128, 0),
cv::Scalar(255, 128, 0),
cv::Scalar(255, 51, 255),
cv::Scalar(255, 51, 255),
cv::Scalar(255, 51, 255),
cv::Scalar(51, 153, 255),
cv::Scalar(51, 153, 255),
cv::Scalar(51, 153, 255),
cv::Scalar(51, 153, 255),
};

for (int j = 0; j < 16; j++)
{
const KeyPoint& p1 = obj.keypoints[joint_pairs[j][0]];
const KeyPoint& p2 = obj.keypoints[joint_pairs[j][1]];

if (p1.prob < 0.2f || p2.prob < 0.2f)
continue;

cv::line(image, p1.p, p2.p, bone_colors[j], 2);
}

// draw joint
for (size_t j = 0; j < obj.keypoints.size(); j++)
{
const KeyPoint& keypoint = obj.keypoints[j];

fprintf(stderr, "%.2f %.2f = %.5f\n", keypoint.p.x, keypoint.p.y, keypoint.prob);

if (keypoint.prob < 0.2f)
continue;

cv::circle(image, keypoint.p, 3, color, -1);
}

cv::rectangle(image, obj.rect, color);

char text[256];
sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);

int baseLine = 0;
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

int x = obj.rect.x;
int y = obj.rect.y - label_size.height - baseLine;
if (y < 0)
y = 0;
if (x + label_size.width > image.cols)
x = image.cols - label_size.width;

cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
cv::Scalar(255, 255, 255), -1);

cv::putText(image, text, cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
}

cv::imshow("image", image);
cv::waitKey(0);
}

int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
return -1;
}

const char* imagepath = argv[1];

cv::Mat m = cv::imread(imagepath, 1);
if (m.empty())
{
fprintf(stderr, "cv::imread %s failed\n", imagepath);
return -1;
}

std::vector<Object> objects;
detect_yolov8_pose(m, objects);

draw_objects(m, objects);

return 0;
}

+ 624
- 0
examples/yolov8_seg.cpp View File

@@ -0,0 +1,624 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// 1. install
// pip3 install -U ultralytics pnnx ncnn
// 2. export yolov8-seg torchscript
// yolo export model=yolov8n-seg.pt format=torchscript
// 3. convert torchscript with static shape
// pnnx yolov8n-seg.torchscript
// 4. modify yolov8n_seg_pnnx.py for dynamic shape inference
// A. modify reshape to support dynamic image sizes
// B. permute tensor before concat and adjust concat axis
// C. drop post-process part
// before:
// v_144 = v_143.view(1, 32, 6400)
// v_150 = v_149.view(1, 32, 1600)
// v_156 = v_155.view(1, 32, 400)
// v_157 = torch.cat((v_144, v_150, v_156), dim=2)
// ...
// v_191 = v_168.view(1, 144, 6400)
// v_192 = v_179.view(1, 144, 1600)
// v_193 = v_190.view(1, 144, 400)
// v_194 = torch.cat((v_191, v_192, v_193), dim=2)
// ...
// v_215 = (v_214, v_138, )
// return v_215
// after:
// v_144 = v_143.view(1, 32, -1).transpose(1, 2)
// v_150 = v_149.view(1, 32, -1).transpose(1, 2)
// v_156 = v_155.view(1, 32, -1).transpose(1, 2)
// v_157 = torch.cat((v_144, v_150, v_156), dim=1)
// ...
// v_191 = v_168.view(1, 144, -1).transpose(1, 2)
// v_192 = v_179.view(1, 144, -1).transpose(1, 2)
// v_193 = v_190.view(1, 144, -1).transpose(1, 2)
// v_194 = torch.cat((v_191, v_192, v_193), dim=1)
// return v_194, v_157, v_138
// 5. re-export yolov8-seg torchscript
// python3 -c 'import yolov8n_seg_pnnx; yolov8n_seg_pnnx.export_torchscript()'
// 6. convert new torchscript with dynamic shape
// pnnx yolov8n_seg_pnnx.py.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320]
// 7. now you get ncnn model files
// mv yolov8n_seg_pnnx.py.ncnn.param yolov8n_seg.ncnn.param
// mv yolov8n_seg_pnnx.py.ncnn.bin yolov8n_seg.ncnn.bin

// the out blob would be a 2-dim tensor with w=176 h=8400
//
// | bbox-reg 16 x 4 | per-class scores(80) |
// +-----+-----+-----+-----+----------------------+
// | dx0 | dy0 | dx1 | dy1 |0.1 0.0 0.0 0.5 ......|
// all /| | | | | . |
// boxes | .. | .. | .. | .. |0.0 0.9 0.0 0.0 ......|
// (8400)| | | | | . |
// \| | | | | . |
// +-----+-----+-----+-----+----------------------+
//

//
// | mask (32) |
// +-----------+
// |0.1........|
// all /| |
// boxes |0.0........|
// (8400)| . |
// \| . |
// +-----------+
//

#include "layer.h"
#include "net.h"

#if defined(USE_NCNN_SIMPLEOCV)
#include "simpleocv.h"
#else
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#endif
#include <float.h>
#include <stdio.h>
#include <vector>

struct Object
{
cv::Rect_<float> rect;
int label;
float prob;
int gindex;
cv::Mat mask;
};

static inline float intersection_area(const Object& a, const Object& b)
{
cv::Rect_<float> inter = a.rect & b.rect;
return inter.area();
}

static void qsort_descent_inplace(std::vector<Object>& objects, int left, int right)
{
int i = left;
int j = right;
float p = objects[(left + right) / 2].prob;

while (i <= j)
{
while (objects[i].prob > p)
i++;

while (objects[j].prob < p)
j--;

if (i <= j)
{
// swap
std::swap(objects[i], objects[j]);

i++;
j--;
}
}

// #pragma omp parallel sections
{
// #pragma omp section
{
if (left < j) qsort_descent_inplace(objects, left, j);
}
// #pragma omp section
{
if (i < right) qsort_descent_inplace(objects, i, right);
}
}
}

static void qsort_descent_inplace(std::vector<Object>& objects)
{
if (objects.empty())
return;

qsort_descent_inplace(objects, 0, objects.size() - 1);
}

static void nms_sorted_bboxes(const std::vector<Object>& objects, std::vector<int>& picked, float nms_threshold, bool agnostic = false)
{
picked.clear();

const int n = objects.size();

std::vector<float> areas(n);
for (int i = 0; i < n; i++)
{
areas[i] = objects[i].rect.area();
}

for (int i = 0; i < n; i++)
{
const Object& a = objects[i];

int keep = 1;
for (int j = 0; j < (int)picked.size(); j++)
{
const Object& b = objects[picked[j]];

if (!agnostic && a.label != b.label)
continue;

// intersection over union
float inter_area = intersection_area(a, b);
float union_area = areas[i] + areas[picked[j]] - inter_area;
// float IoU = inter_area / union_area
if (inter_area / union_area > nms_threshold)
keep = 0;
}

if (keep)
picked.push_back(i);
}
}

static inline float sigmoid(float x)
{
return 1.0f / (1.0f + expf(-x));
}

static void generate_proposals(const ncnn::Mat& pred, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const int w = in_pad.w;
const int h = in_pad.h;

const int num_grid_x = w / stride;
const int num_grid_y = h / stride;

const int reg_max_1 = 16;
const int num_class = pred.w - reg_max_1 * 4; // number of classes. 80 for COCO

for (int y = 0; y < num_grid_y; y++)
{
for (int x = 0; x < num_grid_x; x++)
{
const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1);

// find label with max score
int label = -1;
float score = -FLT_MAX;
{
const ncnn::Mat pred_score = pred_grid.range(reg_max_1 * 4, num_class);

for (int k = 0; k < num_class; k++)
{
float s = pred_score[k];
if (s > score)
{
label = k;
score = s;
}
}

score = sigmoid(score);
}

if (score >= prob_threshold)
{
ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4).clone();

{
ncnn::Layer* softmax = ncnn::create_layer("Softmax");

ncnn::ParamDict pd;
pd.set(0, 1); // axis
pd.set(1, 1);
softmax->load_param(pd);

ncnn::Option opt;
opt.num_threads = 1;
opt.use_packing_layout = false;

softmax->create_pipeline(opt);

softmax->forward_inplace(pred_bbox, opt);

softmax->destroy_pipeline(opt);

delete softmax;
}

float pred_ltrb[4];
for (int k = 0; k < 4; k++)
{
float dis = 0.f;
const float* dis_after_sm = pred_bbox.row(k);
for (int l = 0; l < reg_max_1; l++)
{
dis += l * dis_after_sm[l];
}

pred_ltrb[k] = dis * stride;
}

float pb_cx = (x + 0.5f) * stride;
float pb_cy = (y + 0.5f) * stride;

float x0 = pb_cx - pred_ltrb[0];
float y0 = pb_cy - pred_ltrb[1];
float x1 = pb_cx + pred_ltrb[2];
float y1 = pb_cy + pred_ltrb[3];

Object obj;
obj.rect.x = x0;
obj.rect.y = y0;
obj.rect.width = x1 - x0;
obj.rect.height = y1 - y0;
obj.label = label;
obj.prob = score;
obj.gindex = y * num_grid_x + x;

objects.push_back(obj);
}
}
}
}

static void generate_proposals(const ncnn::Mat& pred, const std::vector<int>& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector<Object>& objects)
{
const int w = in_pad.w;
const int h = in_pad.h;

int pred_row_offset = 0;
for (size_t i = 0; i < strides.size(); i++)
{
const int stride = strides[i];

const int num_grid_x = w / stride;
const int num_grid_y = h / stride;
const int num_grid = num_grid_x * num_grid_y;

std::vector<Object> objects_stride;
generate_proposals(pred.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects_stride);

for (size_t j = 0; j < objects_stride.size(); j++)
{
Object obj = objects_stride[j];
obj.gindex += pred_row_offset;
objects.push_back(obj);
}

pred_row_offset += num_grid;
}
}

static int detect_yolov8_seg(const cv::Mat& bgr, std::vector<Object>& objects)
{
ncnn::Net yolov8;

yolov8.opt.use_vulkan_compute = true;
// yolov8.opt.use_bf16_storage = true;

// https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets
yolov8.load_param("yolov8n_seg.ncnn.param");
yolov8.load_model("yolov8n_seg.ncnn.bin");
// yolov8.load_param("yolov8s_seg.ncnn.param");
// yolov8.load_model("yolov8s_seg.ncnn.bin");
// yolov8.load_param("yolov8m_seg.ncnn.param");
// yolov8.load_model("yolov8m_seg.ncnn.bin");

const int target_size = 640;
const float prob_threshold = 0.25f;
const float nms_threshold = 0.45f;
const float mask_threshold = 0.5f;

int img_w = bgr.cols;
int img_h = bgr.rows;

// ultralytics/cfg/models/v8/yolov8.yaml
std::vector<int> strides(3);
strides[0] = 8;
strides[1] = 16;
strides[2] = 32;
const int max_stride = 32;

// letterbox pad to multiple of max_stride
int w = img_w;
int h = img_h;
float scale = 1.f;
if (w > h)
{
scale = (float)target_size / w;
w = target_size;
h = h * scale;
}
else
{
scale = (float)target_size / h;
h = target_size;
w = w * scale;
}

ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h);

// letterbox pad to target_size rectangle
int wpad = (w + max_stride - 1) / max_stride * max_stride - w;
int hpad = (h + max_stride - 1) / max_stride * max_stride - h;
ncnn::Mat in_pad;
ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f);

const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f};
in_pad.substract_mean_normalize(0, norm_vals);

ncnn::Extractor ex = yolov8.create_extractor();

ex.input("in0", in_pad);

ncnn::Mat out;
ex.extract("out0", out);

std::vector<Object> proposals;
generate_proposals(out, strides, in_pad, prob_threshold, proposals);

// sort all proposals by score from highest to lowest
qsort_descent_inplace(proposals);

// apply nms with nms_threshold
std::vector<int> picked;
nms_sorted_bboxes(proposals, picked, nms_threshold);

int count = picked.size();
if (count == 0)
return 0;

ncnn::Mat mask_feat;
ex.extract("out1", mask_feat);

ncnn::Mat mask_protos;
ex.extract("out2", mask_protos);

ncnn::Mat objects_mask_feat(mask_feat.w, 1, count);

objects.resize(count);
for (int i = 0; i < count; i++)
{
objects[i] = proposals[picked[i]];

// adjust offset to original unpadded
float x0 = (objects[i].rect.x - (wpad / 2)) / scale;
float y0 = (objects[i].rect.y - (hpad / 2)) / scale;
float x1 = (objects[i].rect.x + objects[i].rect.width - (wpad / 2)) / scale;
float y1 = (objects[i].rect.y + objects[i].rect.height - (hpad / 2)) / scale;

// clip
x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);

objects[i].rect.x = x0;
objects[i].rect.y = y0;
objects[i].rect.width = x1 - x0;
objects[i].rect.height = y1 - y0;

// pick mask feat
memcpy(objects_mask_feat.channel(i), mask_feat.row(objects[i].gindex), mask_feat.w * sizeof(float));
}

// process mask
ncnn::Mat objects_mask;
{
ncnn::Layer* gemm = ncnn::create_layer("Gemm");

ncnn::ParamDict pd;
pd.set(6, 1); // constantC
pd.set(7, count); // constantM
pd.set(8, mask_protos.w * mask_protos.h); // constantN
pd.set(9, mask_feat.w); // constantK
pd.set(10, -1); // constant_broadcast_type_C
pd.set(11, 1); // output_N1M
gemm->load_param(pd);

ncnn::Option opt;
opt.num_threads = 1;
opt.use_packing_layout = false;

gemm->create_pipeline(opt);

std::vector<ncnn::Mat> gemm_inputs(2);
gemm_inputs[0] = objects_mask_feat;
gemm_inputs[1] = mask_protos.reshape(mask_protos.w * mask_protos.h, 1, mask_protos.c);
std::vector<ncnn::Mat> gemm_outputs(1);
gemm->forward(gemm_inputs, gemm_outputs, opt);
objects_mask = gemm_outputs[0].reshape(mask_protos.w, mask_protos.h, count);

gemm->destroy_pipeline(opt);

delete gemm;
}
{
ncnn::Layer* sigmoid = ncnn::create_layer("Sigmoid");

ncnn::Option opt;
opt.num_threads = 1;
opt.use_packing_layout = false;

sigmoid->create_pipeline(opt);

sigmoid->forward_inplace(objects_mask, opt);

sigmoid->destroy_pipeline(opt);

delete sigmoid;
}

// resize mask map
{
ncnn::Mat objects_mask_resized;
ncnn::resize_bilinear(objects_mask, objects_mask_resized, in_pad.w / scale, in_pad.h / scale);
objects_mask = objects_mask_resized;
}

// create per-object mask
for (int i = 0; i < count; i++)
{
Object& obj = objects[i];

const ncnn::Mat mm = objects_mask.channel(i);

obj.mask = cv::Mat((int)obj.rect.height, (int)obj.rect.width, CV_8UC1);

// adjust offset to original unpadded and clip inside object box
for (int y = 0; y < (int)obj.rect.height; y++)
{
const float* pmm = mm.row((int)(hpad / 2 / scale + obj.rect.y + y)) + (int)(wpad / 2 / scale + obj.rect.x);
uchar* pmask = obj.mask.ptr<uchar>(y);
for (int x = 0; x < (int)obj.rect.width; x++)
{
pmask[x] = pmm[x] > mask_threshold ? 1 : 0;
}
}
}

return 0;
}

static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
{
static const char* class_names[] = {
"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
"elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
"skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
"sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
"potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
"microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
"hair drier", "toothbrush"
};

static cv::Scalar colors[] = {
cv::Scalar(244, 67, 54),
cv::Scalar(233, 30, 99),
cv::Scalar(156, 39, 176),
cv::Scalar(103, 58, 183),
cv::Scalar(63, 81, 181),
cv::Scalar(33, 150, 243),
cv::Scalar(3, 169, 244),
cv::Scalar(0, 188, 212),
cv::Scalar(0, 150, 136),
cv::Scalar(76, 175, 80),
cv::Scalar(139, 195, 74),
cv::Scalar(205, 220, 57),
cv::Scalar(255, 235, 59),
cv::Scalar(255, 193, 7),
cv::Scalar(255, 152, 0),
cv::Scalar(255, 87, 34),
cv::Scalar(121, 85, 72),
cv::Scalar(158, 158, 158),
cv::Scalar(96, 125, 139)
};

cv::Mat image = bgr.clone();

for (size_t i = 0; i < objects.size(); i++)
{
const Object& obj = objects[i];

const cv::Scalar& color = colors[i % 19];

fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);

for (int y = 0; y < (int)obj.rect.height; y++)
{
const uchar* maskptr = obj.mask.ptr<const uchar>(y);
uchar* bgrptr = image.ptr<uchar>((int)obj.rect.y + y) + (int)obj.rect.x * 3;
for (int x = 0; x < (int)obj.rect.width; x++)
{
if (maskptr[x])
{
bgrptr[0] = bgrptr[0] * 0.5 + color[0] * 0.5;
bgrptr[1] = bgrptr[1] * 0.5 + color[1] * 0.5;
bgrptr[2] = bgrptr[2] * 0.5 + color[2] * 0.5;
}
bgrptr += 3;
}
}

cv::rectangle(image, obj.rect, color);

char text[256];
sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);

int baseLine = 0;
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);

int x = obj.rect.x;
int y = obj.rect.y - label_size.height - baseLine;
if (y < 0)
y = 0;
if (x + label_size.width > image.cols)
x = image.cols - label_size.width;

cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
cv::Scalar(255, 255, 255), -1);

cv::putText(image, text, cv::Point(x, y + label_size.height),
cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
}

cv::imshow("image", image);
cv::waitKey(0);
}

int main(int argc, char** argv)
{
if (argc != 2)
{
fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
return -1;
}

const char* imagepath = argv[1];

cv::Mat m = cv::imread(imagepath, 1);
if (m.empty())
{
fprintf(stderr, "cv::imread %s failed\n", imagepath);
return -1;
}

std::vector<Object> objects;
detect_yolov8_seg(m, objects);

draw_objects(m, objects);

return 0;
}

+ 365
- 684
src/layer/reduction.cpp
File diff suppressed because it is too large
View File


+ 58
- 46
tests/test_copyto_1.cpp View File

@@ -14,58 +14,70 @@

#include "testutil.h"

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_copyto(const ncnn::Mat& self, const ncnn::Mat& src, const ncnn::Mat& starts, const ncnn::Mat& axes)
static int test_copyto(const ncnn::Mat& self, const ncnn::Mat& src, const std::vector<int>& starts_array, const std::vector<int>& axes_array)
{
ncnn::Mat starts(starts_array.size());
{
int* p = starts;
for (size_t i = 0; i < starts_array.size(); i++)
{
p[i] = starts_array[i];
}
}

ncnn::Mat axes(axes_array.size());
{
int* p = axes;
for (size_t i = 0; i < axes_array.size(); i++)
{
p[i] = axes_array[i];
}
}

ncnn::ParamDict pd;
pd.set(9, starts); // starts
pd.set(11, axes); // axes
@@ -81,9 +93,9 @@ static int test_copyto(const ncnn::Mat& self, const ncnn::Mat& src, const ncnn::
{
fprintf(stderr, "test_copyto failed self.dims=%d self=(%d %d %d %d) src.dims=%d src=(%d %d %d %d)", self.dims, self.w, self.h, self.d, self.c, src.dims, src.w, src.h, src.d, src.c);
fprintf(stderr, " starts=");
print_int_array(starts);
print_int_array(starts_array);
fprintf(stderr, " axes=");
print_int_array(axes);
print_int_array(axes_array);
fprintf(stderr, "\n");
}

@@ -111,10 +123,10 @@ static int test_copyto_0()
const ncnn::Mat& src = b[j];

int ret = 0
|| test_copyto(self, src, IntArrayMat(0), IntArrayMat(0))
|| test_copyto(self, src, IntArrayMat(13), IntArrayMat(-1))
|| test_copyto(self, src, IntArrayMat(28), IntArrayMat(0))
|| test_copyto(self, src, IntArrayMat(32), ncnn::Mat());
|| test_copyto(self, src, IntArray(0), IntArray(0))
|| test_copyto(self, src, IntArray(13), IntArray(-1))
|| test_copyto(self, src, IntArray(28), IntArray(0))
|| test_copyto(self, src, IntArray(32), std::vector<int>());

if (ret != 0)
return ret;
@@ -148,10 +160,10 @@ static int test_copyto_1()
const ncnn::Mat& src = b[j];

int ret = 0
|| test_copyto(self, src, IntArrayMat(0, 0), IntArrayMat(0, 1))
|| test_copyto(self, src, IntArrayMat(13, 1), IntArrayMat(-2, -1))
|| test_copyto(self, src, IntArrayMat(28, 3), IntArrayMat(0, 1))
|| test_copyto(self, src, IntArrayMat(32, 10), IntArrayMat(0, 1));
|| test_copyto(self, src, IntArray(0, 0), IntArray(0, 1))
|| test_copyto(self, src, IntArray(13, 1), IntArray(-2, -1))
|| test_copyto(self, src, IntArray(28, 3), IntArray(0, 1))
|| test_copyto(self, src, IntArray(32, 10), IntArray(0, 1));

if (ret != 0)
return ret;
@@ -188,10 +200,10 @@ static int test_copyto_2()
const ncnn::Mat& src = b[j];

int ret = 0
|| test_copyto(self, src, IntArrayMat(0, 0, 0), IntArrayMat(0, 1, 2))
|| test_copyto(self, src, IntArrayMat(13, 1, 0), IntArrayMat(-3, -2, -1))
|| test_copyto(self, src, IntArrayMat(28, 3, 4), IntArrayMat(0, 1, 2))
|| test_copyto(self, src, IntArrayMat(32, 0, 5), IntArrayMat(0, 1, 2));
|| test_copyto(self, src, IntArray(0, 0, 0), IntArray(0, 1, 2))
|| test_copyto(self, src, IntArray(13, 1, 0), IntArray(-3, -2, -1))
|| test_copyto(self, src, IntArray(28, 3, 4), IntArray(0, 1, 2))
|| test_copyto(self, src, IntArray(32, 0, 5), IntArray(0, 1, 2));

if (ret != 0)
return ret;
@@ -231,10 +243,10 @@ static int test_copyto_3()
const ncnn::Mat& src = b[j];

int ret = 0
|| test_copyto(self, src, IntArrayMat(0, 0, 0, 0), IntArrayMat(0, 1, 2, 3))
|| test_copyto(self, src, IntArrayMat(13, 1, 1, 0), IntArrayMat(-4, -3, 2, 3))
|| test_copyto(self, src, IntArrayMat(28, 0, 3, 4), IntArrayMat(0, 1, 2, 3))
|| test_copyto(self, src, IntArrayMat(32, 2, 0, 5), IntArrayMat(0, 1, 2, 3));
|| test_copyto(self, src, IntArray(0, 0, 0, 0), IntArray(0, 1, 2, 3))
|| test_copyto(self, src, IntArray(13, 1, 1, 0), IntArray(-4, -3, 2, 3))
|| test_copyto(self, src, IntArray(28, 0, 3, 4), IntArray(0, 1, 2, 3))
|| test_copyto(self, src, IntArray(32, 2, 0, 5), IntArray(0, 1, 2, 3));

if (ret != 0)
return ret;


+ 306
- 295
tests/test_crop_1.cpp View File

@@ -14,58 +14,79 @@

#include "testutil.h"

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_crop(const ncnn::Mat& a, const ncnn::Mat& starts, const ncnn::Mat& ends, const ncnn::Mat& axes)
static int test_crop(const ncnn::Mat& a, const std::vector<int>& starts_array, const std::vector<int>& ends_array, const std::vector<int>& axes_array)
{
ncnn::Mat starts(starts_array.size());
{
int* p = starts;
for (size_t i = 0; i < starts_array.size(); i++)
{
p[i] = starts_array[i];
}
}

ncnn::Mat ends(ends_array.size());
{
int* p = ends;
for (size_t i = 0; i < ends_array.size(); i++)
{
p[i] = ends_array[i];
}
}

ncnn::Mat axes(axes_array.size());
{
int* p = axes;
for (size_t i = 0; i < axes_array.size(); i++)
{
p[i] = axes_array[i];
}
}

ncnn::ParamDict pd;
pd.set(9, starts); // starts
pd.set(10, ends); // ends
@@ -78,282 +99,272 @@ static int test_crop(const ncnn::Mat& a, const ncnn::Mat& starts, const ncnn::Ma
{
fprintf(stderr, "test_crop failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " starts=");
print_int_array(starts);
print_int_array(starts_array);
fprintf(stderr, " ends=");
print_int_array(ends);
print_int_array(ends_array);
fprintf(stderr, " axes=");
print_int_array(axes);
print_int_array(axes_array);
fprintf(stderr, "\n");
}

return ret;
}

static int test_crop_1(const ncnn::Mat& a)
static int test_crop_1d(const ncnn::Mat& a)
{
return 0
|| test_crop(a, IntArrayMat(12), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(16), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(11), IntArrayMat(11 + 16), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(12 + 7), IntArrayMat(-1))
|| test_crop(a, IntArrayMat(16), IntArrayMat(16 + 12), ncnn::Mat())
|| test_crop(a, IntArrayMat(11), IntArrayMat(-7 + 1), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(-12 + 1), IntArrayMat(-1))
|| test_crop(a, IntArrayMat(16), IntArrayMat(-16 + 1), ncnn::Mat());
std::vector<int> params[][3] = {
{IntArray(12), IntArray(-233), IntArray(0)},
{IntArray(16), IntArray(-233), IntArray(0)},
{IntArray(11), IntArray(11 + 16), IntArray(0)},
{IntArray(12), IntArray(12 + 7), IntArray(-1)},
{IntArray(16), IntArray(16 + 12), std::vector<int>()},
{IntArray(11), IntArray(-7 + 1), IntArray(0)},
{IntArray(12), IntArray(-12 + 1), IntArray(-1)},
{IntArray(16), IntArray(-16 + 1), std::vector<int>()}
};

for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++)
{
int ret = test_crop(a, params[i][0], params[i][1], params[i][2]);
if (ret)
return ret;
}

return 0;
}

static int test_crop_4(const ncnn::Mat& a)
static int test_crop_2d(const ncnn::Mat& a)
{
return 0
|| test_crop(a, IntArrayMat(12), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(4), IntArrayMat(-233), IntArrayMat(1))
|| test_crop(a, IntArrayMat(5, 11), IntArrayMat(-233, -233), IntArrayMat(0, 1))

|| test_crop(a, IntArrayMat(11), IntArrayMat(11 + 16), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(12 + 7), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(8 + 12), IntArrayMat(-2))

|| test_crop(a, IntArrayMat(5), IntArrayMat(8), IntArrayMat(1))
|| test_crop(a, IntArrayMat(6), IntArrayMat(9), IntArrayMat(1))
|| test_crop(a, IntArrayMat(4), IntArrayMat(12), IntArrayMat(-1))

|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 7, 11), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 12, 12), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 16, 10), IntArrayMat(0, -1))

|| test_crop(a, IntArrayMat(11), IntArrayMat(-16 + 1), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(-7 + 1), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(-12 + 1), IntArrayMat(-2))

|| test_crop(a, IntArrayMat(5), IntArrayMat(-5 + 1), IntArrayMat(1))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-6 + 1), IntArrayMat(1))
|| test_crop(a, IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(-1))
std::vector<int> params[][3] = {
{IntArray(12), IntArray(-233), IntArray(0)},
{IntArray(8), IntArray(-233), IntArray(0)},
{IntArray(4), IntArray(-233), IntArray(1)},
{IntArray(5, 11), IntArray(-233, -233), IntArray(0, 1)},
{IntArray(11), IntArray(11 + 16), IntArray(0)},
{IntArray(12), IntArray(12 + 7), IntArray(0)},
{IntArray(8), IntArray(8 + 12), IntArray(-2)},
{IntArray(5), IntArray(8), IntArray(1)},
{IntArray(6), IntArray(9), IntArray(1)},
{IntArray(4), IntArray(12), IntArray(-1)},
{IntArray(11, 5), IntArray(11 + 7, 11), IntArray(0, 1)},
{IntArray(12, 6), IntArray(12 + 12, 12), IntArray(0, 1)},
{IntArray(8, 4), IntArray(8 + 16, 10), IntArray(0, -1)},
{IntArray(11), IntArray(-16 + 1), IntArray(0)},
{IntArray(12), IntArray(-7 + 1), IntArray(0)},
{IntArray(8), IntArray(-12 + 1), IntArray(-2)},
{IntArray(5), IntArray(-5 + 1), IntArray(1)},
{IntArray(6), IntArray(-6 + 1), IntArray(1)},
{IntArray(4), IntArray(-4 + 1), IntArray(-1)},
{IntArray(11, 5), IntArray(-12 + 1, -6 + 1), IntArray(0, 1)},
{IntArray(12, 6), IntArray(-16 + 1, -5 + 1), IntArray(0, 1)},
{IntArray(8, 4), IntArray(-7 + 1, -4 + 1), IntArray(-2, -1)}
};

for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++)
{
int ret = test_crop(a, params[i][0], params[i][1], params[i][2]);
if (ret)
return ret;
}

|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(-12 + 1, -6 + 1), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(-2, -1));
return 0;
}

static int test_crop_7(const ncnn::Mat& a)
static int test_crop_3d(const ncnn::Mat& a)
{
return 0
|| test_crop(a, IntArrayMat(11), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(5), IntArrayMat(-233), IntArrayMat(1))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-233), IntArrayMat(2))
|| test_crop(a, IntArrayMat(4), IntArrayMat(-233), IntArrayMat(-1))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(-233, -233), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(-233, -233), IntArrayMat(0, -1))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(-233, -233), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(6, 6), IntArrayMat(-233, -233), IntArrayMat(1, -1))
|| test_crop(a, IntArrayMat(11, 5, 5), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, -1))

|| test_crop(a, IntArrayMat(11), IntArrayMat(11 + 7), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(12 + 12), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(8 + 16), IntArrayMat(0))

|| test_crop(a, IntArrayMat(5), IntArrayMat(13), IntArrayMat(1))
|| test_crop(a, IntArrayMat(6), IntArrayMat(12), IntArrayMat(1))
|| test_crop(a, IntArrayMat(4), IntArrayMat(11), IntArrayMat(-2))

|| test_crop(a, IntArrayMat(5), IntArrayMat(12), IntArrayMat(2))
|| test_crop(a, IntArrayMat(6), IntArrayMat(11), IntArrayMat(2))
|| test_crop(a, IntArrayMat(4), IntArrayMat(13), IntArrayMat(-1))

|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 7, 11), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 16, 12), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 12, 13), IntArrayMat(0, -2))

|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 16, 13), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 12, 11), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 7, 12), IntArrayMat(0, -1))

|| test_crop(a, IntArrayMat(5, 4), IntArrayMat(12, 12), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(6, 3), IntArrayMat(13, 13), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(4, 2), IntArrayMat(11, 11), IntArrayMat(-2, -1))

|| test_crop(a, IntArrayMat(11, 5, 2), IntArrayMat(11 + 7, 11, 11), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(12, 6, 4), IntArrayMat(12 + 16, 12, 12), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(8, 4, 3), IntArrayMat(8 + 12, 13, 13), IntArrayMat(-3, -2, -1))

|| test_crop(a, IntArrayMat(11), IntArrayMat(-7 + 1), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(-12 + 1), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(-16 + 1), IntArrayMat(-3))

|| test_crop(a, IntArrayMat(5), IntArrayMat(-6 + 1), IntArrayMat(1))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-5 + 1), IntArrayMat(1))
|| test_crop(a, IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(-2))

|| test_crop(a, IntArrayMat(5), IntArrayMat(-5 + 1), IntArrayMat(2))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-4 + 1), IntArrayMat(2))
|| test_crop(a, IntArrayMat(4), IntArrayMat(-6 + 1), IntArrayMat(-1))

|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(-12 + 1, -6 + 1), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(-3, -2))

|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(-12 + 1, -6 + 1), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(-3, -1))

|| test_crop(a, IntArrayMat(5, 2), IntArrayMat(-5 + 1, -5 + 1), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(6, 4), IntArrayMat(-4 + 1, -4 + 1), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(4, 3), IntArrayMat(-6 + 1, -6 + 1), IntArrayMat(-2, -1))
std::vector<int> params[][3] = {
{IntArray(11), IntArray(-233), IntArray(0)},
{IntArray(8), IntArray(-233), IntArray(0)},
{IntArray(5), IntArray(-233), IntArray(1)},
{IntArray(6), IntArray(-233), IntArray(2)},
{IntArray(4), IntArray(-233), IntArray(-1)},
{IntArray(12, 6), IntArray(-233, -233), IntArray(0, 1)},
{IntArray(11, 5), IntArray(-233, -233), IntArray(0, -1)},
{IntArray(8, 4), IntArray(-233, -233), IntArray(0, 2)},
{IntArray(6, 6), IntArray(-233, -233), IntArray(1, -1)},
{IntArray(11, 5, 5), IntArray(-233, -233, -233), IntArray(0, 1, 2)},
{IntArray(8, 4, 4), IntArray(-233, -233, -233), IntArray(0, 1, -1)},
{IntArray(11), IntArray(11 + 7), IntArray(0)},
{IntArray(12), IntArray(12 + 12), IntArray(0)},
{IntArray(8), IntArray(8 + 16), IntArray(0)},
{IntArray(5), IntArray(13), IntArray(1)},
{IntArray(6), IntArray(12), IntArray(1)},
{IntArray(4), IntArray(11), IntArray(-2)},
{IntArray(5), IntArray(12), IntArray(2)},
{IntArray(6), IntArray(11), IntArray(2)},
{IntArray(4), IntArray(13), IntArray(-1)},
{IntArray(11, 5), IntArray(11 + 7, 11), IntArray(0, 1)},
{IntArray(12, 6), IntArray(12 + 16, 12), IntArray(0, 1)},
{IntArray(8, 4), IntArray(8 + 12, 13), IntArray(0, -2)},
{IntArray(11, 5), IntArray(11 + 16, 13), IntArray(0, 2)},
{IntArray(12, 6), IntArray(12 + 12, 11), IntArray(0, 2)},
{IntArray(8, 4), IntArray(8 + 7, 12), IntArray(0, -1)},
{IntArray(5, 4), IntArray(12, 12), IntArray(1, 2)},
{IntArray(6, 3), IntArray(13, 13), IntArray(1, 2)},
{IntArray(4, 2), IntArray(11, 11), IntArray(-2, -1)},
{IntArray(11, 5, 2), IntArray(11 + 7, 11, 11), IntArray(0, 1, 2)},
{IntArray(12, 6, 4), IntArray(12 + 16, 12, 12), IntArray(0, 1, 2)},
{IntArray(8, 4, 3), IntArray(8 + 12, 13, 13), IntArray(-3, -2, -1)},
{IntArray(11), IntArray(-7 + 1), IntArray(0)},
{IntArray(12), IntArray(-12 + 1), IntArray(0)},
{IntArray(8), IntArray(-16 + 1), IntArray(-3)},
{IntArray(5), IntArray(-6 + 1), IntArray(1)},
{IntArray(6), IntArray(-5 + 1), IntArray(1)},
{IntArray(4), IntArray(-4 + 1), IntArray(-2)},
{IntArray(5), IntArray(-5 + 1), IntArray(2)},
{IntArray(6), IntArray(-4 + 1), IntArray(2)},
{IntArray(4), IntArray(-6 + 1), IntArray(-1)},
{IntArray(11, 5), IntArray(-7 + 1, -4 + 1), IntArray(0, 1)},
{IntArray(12, 6), IntArray(-12 + 1, -6 + 1), IntArray(0, 1)},
{IntArray(8, 4), IntArray(-16 + 1, -5 + 1), IntArray(-3, -2)},
{IntArray(11, 5), IntArray(-12 + 1, -6 + 1), IntArray(0, 2)},
{IntArray(12, 6), IntArray(-16 + 1, -5 + 1), IntArray(0, 2)},
{IntArray(8, 4), IntArray(-7 + 1, -4 + 1), IntArray(-3, -1)},
{IntArray(5, 2), IntArray(-5 + 1, -5 + 1), IntArray(1, 2)},
{IntArray(6, 4), IntArray(-4 + 1, -4 + 1), IntArray(1, 2)},
{IntArray(4, 3), IntArray(-6 + 1, -6 + 1), IntArray(-2, -1)},
{IntArray(11, 5, 4), IntArray(-7 + 1, -5 + 1, -5 + 1), IntArray(0, 1, 2)},
{IntArray(12, 6, 3), IntArray(-12 + 1, -6 + 1, -6 + 1), IntArray(0, 1, 2)},
{IntArray(8, 4, 2), IntArray(-16 + 1, -4 + 1, -4 + 1), IntArray(-3, -2, -1)}
};

for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++)
{
int ret = test_crop(a, params[i][0], params[i][1], params[i][2]);
if (ret)
return ret;
}

|| test_crop(a, IntArrayMat(11, 5, 4), IntArrayMat(-7 + 1, -5 + 1, -5 + 1), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(12, 6, 3), IntArrayMat(-12 + 1, -6 + 1, -6 + 1), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(8, 4, 2), IntArrayMat(-16 + 1, -4 + 1, -4 + 1), IntArrayMat(-3, -2, -1));
return 0;
}

static int test_crop_10(const ncnn::Mat& a)
static int test_crop_4d(const ncnn::Mat& a)
{
return 0
|| test_crop(a, IntArrayMat(11), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(-233), IntArrayMat(0))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-233), IntArrayMat(1))
|| test_crop(a, IntArrayMat(5), IntArrayMat(-233), IntArrayMat(2))
|| test_crop(a, IntArrayMat(4), IntArrayMat(-233), IntArrayMat(-2))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-233), IntArrayMat(3))
|| test_crop(a, IntArrayMat(5), IntArrayMat(-233), IntArrayMat(-1))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(-233, -233), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(-233, -233), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(-233, -233), IntArrayMat(-4, -2))
|| test_crop(a, IntArrayMat(4, 4), IntArrayMat(-233, -233), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(-233, -233), IntArrayMat(0, 3))
|| test_crop(a, IntArrayMat(5, 5), IntArrayMat(-233, -233), IntArrayMat(1, 3))
|| test_crop(a, IntArrayMat(4, 4), IntArrayMat(-233, -233), IntArrayMat(2, 3))
|| test_crop(a, IntArrayMat(12, 6, 6), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(11, 5, 5), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, 3))
|| test_crop(a, IntArrayMat(12, 6, 6), IntArrayMat(-233, -233, -233), IntArrayMat(0, 2, 3))
|| test_crop(a, IntArrayMat(11, 5, 5), IntArrayMat(-233, -233, -233), IntArrayMat(0, 2, 3))
|| test_crop(a, IntArrayMat(4, 4, 4), IntArrayMat(-233, -233, -233), IntArrayMat(1, 2, 3))
|| test_crop(a, IntArrayMat(6, 6, 6), IntArrayMat(-233, -233, -233), IntArrayMat(1, 2, 3))
|| test_crop(a, IntArrayMat(11, 5, 5, 5), IntArrayMat(-233, -233, -233, -233), IntArrayMat(0, 1, 2, 3))
|| test_crop(a, IntArrayMat(8, 4, 4, 4), IntArrayMat(-233, -233, -233, -233), IntArrayMat(0, 1, 2, 3))
|| test_crop(a, IntArrayMat(12, 6, 6, 6), IntArrayMat(-233, -233, -233, -233), IntArrayMat(-4, -3, -2, -1))

|| test_crop(a, IntArrayMat(11), IntArrayMat(11 + 16), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(12 + 7), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(8 + 12), IntArrayMat(-4))

|| test_crop(a, IntArrayMat(5), IntArrayMat(11), IntArrayMat(1))
|| test_crop(a, IntArrayMat(6), IntArrayMat(13), IntArrayMat(1))
|| test_crop(a, IntArrayMat(4), IntArrayMat(12), IntArrayMat(-3))

|| test_crop(a, IntArrayMat(3), IntArrayMat(12), IntArrayMat(2))
|| test_crop(a, IntArrayMat(4), IntArrayMat(13), IntArrayMat(2))
|| test_crop(a, IntArrayMat(5), IntArrayMat(11), IntArrayMat(-2))

|| test_crop(a, IntArrayMat(1), IntArrayMat(8), IntArrayMat(3))
|| test_crop(a, IntArrayMat(2), IntArrayMat(7), IntArrayMat(3))
|| test_crop(a, IntArrayMat(3), IntArrayMat(6), IntArrayMat(-1))

|| test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 7, 11), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 12, 12), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 16, 13), IntArrayMat(-4, -3))

|| test_crop(a, IntArrayMat(11, 4), IntArrayMat(11 + 12, 13), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(12, 3), IntArrayMat(12 + 16, 11), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(8, 2), IntArrayMat(8 + 7, 12), IntArrayMat(-4, -2))

|| test_crop(a, IntArrayMat(11, 1), IntArrayMat(11 + 16, 5), IntArrayMat(0, 3))
|| test_crop(a, IntArrayMat(12, 2), IntArrayMat(12 + 7, 6), IntArrayMat(0, 3))
|| test_crop(a, IntArrayMat(8, 3), IntArrayMat(8 + 12, 7), IntArrayMat(-4, -1))

|| test_crop(a, IntArrayMat(3, 3), IntArrayMat(13, 4), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(4, 2), IntArrayMat(12, 3), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(5, 1), IntArrayMat(11, 2), IntArrayMat(-3, -2))

|| test_crop(a, IntArrayMat(5, 5), IntArrayMat(11, 8), IntArrayMat(1, 3))
|| test_crop(a, IntArrayMat(4, 6), IntArrayMat(12, 9), IntArrayMat(1, 3))
|| test_crop(a, IntArrayMat(3, 4), IntArrayMat(13, 7), IntArrayMat(-3, -1))

|| test_crop(a, IntArrayMat(2, 3), IntArrayMat(12, 9), IntArrayMat(2, 3))
|| test_crop(a, IntArrayMat(3, 2), IntArrayMat(11, 7), IntArrayMat(2, 3))
|| test_crop(a, IntArrayMat(4, 1), IntArrayMat(10, 8), IntArrayMat(-2, -1))

|| test_crop(a, IntArrayMat(11, 2, 2), IntArrayMat(11 + 6, 9, 9), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(12, 3, 3), IntArrayMat(12 + 1, 10, 10), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(8 + 3, 11, 11), IntArrayMat(-4, -3, -2))

|| test_crop(a, IntArrayMat(11, 4, 4), IntArrayMat(11 + 12, 12, 12), IntArrayMat(0, 1, 3))
|| test_crop(a, IntArrayMat(12, 5, 5), IntArrayMat(12 + 8, 11, 11), IntArrayMat(0, 1, 3))
|| test_crop(a, IntArrayMat(8, 6, 6), IntArrayMat(8 + 4, 13, 13), IntArrayMat(-4, -3, -1))

|| test_crop(a, IntArrayMat(11, 1, 4), IntArrayMat(11 + 5, 12, 12), IntArrayMat(0, 2, 3))
|| test_crop(a, IntArrayMat(12, 3, 3), IntArrayMat(12 + 3, 11, 11), IntArrayMat(0, 2, 3))
|| test_crop(a, IntArrayMat(8, 2, 5), IntArrayMat(8 + 2, 10, 10), IntArrayMat(-4, -2, -1))

|| test_crop(a, IntArrayMat(1, 1, 1), IntArrayMat(7, 7, 7), IntArrayMat(1, 2, 3))
|| test_crop(a, IntArrayMat(2, 2, 2), IntArrayMat(8, 9, 10), IntArrayMat(1, 2, 3))
|| test_crop(a, IntArrayMat(3, 3, 3), IntArrayMat(11, 12, 13), IntArrayMat(-3, -2, -1))

|| test_crop(a, IntArrayMat(11, 2, 3, 6), IntArrayMat(11 + 11, 10, 12, 11), IntArrayMat(0, 1, 2, 3))
|| test_crop(a, IntArrayMat(12, 3, 4, 5), IntArrayMat(12 + 12, 9, 11, 13), IntArrayMat(0, 1, 2, 3))
|| test_crop(a, IntArrayMat(8, 4, 5, 4), IntArrayMat(8 + 8, 8, 10, 12), IntArrayMat(-4, -3, -2, -1))

|| test_crop(a, IntArrayMat(11), IntArrayMat(-7 + 1), IntArrayMat(0))
|| test_crop(a, IntArrayMat(12), IntArrayMat(-12 + 1), IntArrayMat(0))
|| test_crop(a, IntArrayMat(8), IntArrayMat(-16 + 1), IntArrayMat(-4))

|| test_crop(a, IntArrayMat(5), IntArrayMat(-6 + 1), IntArrayMat(1))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-5 + 1), IntArrayMat(1))
|| test_crop(a, IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(-3))

|| test_crop(a, IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(2))
|| test_crop(a, IntArrayMat(5), IntArrayMat(-5 + 1), IntArrayMat(2))
|| test_crop(a, IntArrayMat(6), IntArrayMat(-6 + 1), IntArrayMat(-2))

|| test_crop(a, IntArrayMat(1), IntArrayMat(-5 + 1), IntArrayMat(3))
|| test_crop(a, IntArrayMat(2), IntArrayMat(-4 + 1), IntArrayMat(3))
|| test_crop(a, IntArrayMat(3), IntArrayMat(-3 + 1), IntArrayMat(-1))

|| test_crop(a, IntArrayMat(11, 3), IntArrayMat(-7 + 1, -3 + 1), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(12, 4), IntArrayMat(-12 + 1, -4 + 1), IntArrayMat(0, 1))
|| test_crop(a, IntArrayMat(8, 5), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(-4, -3))

|| test_crop(a, IntArrayMat(11, 1), IntArrayMat(-12 + 1, -5 + 1), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(12, 2), IntArrayMat(-16 + 1, -4 + 1), IntArrayMat(0, 2))
|| test_crop(a, IntArrayMat(8, 3), IntArrayMat(-7 + 1, -6 + 1), IntArrayMat(-4, -2))

|| test_crop(a, IntArrayMat(11, 3), IntArrayMat(-12 + 1, -2 + 1), IntArrayMat(0, 3))
|| test_crop(a, IntArrayMat(12, 4), IntArrayMat(-16 + 1, -3 + 1), IntArrayMat(0, 3))
|| test_crop(a, IntArrayMat(8, 5), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(-4, -1))

|| test_crop(a, IntArrayMat(2, 3), IntArrayMat(-4 + 1, -2 + 1), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(3, 4), IntArrayMat(-2 + 1, -3 + 1), IntArrayMat(1, 2))
|| test_crop(a, IntArrayMat(4, 5), IntArrayMat(-3 + 1, -4 + 1), IntArrayMat(-3, -2))

|| test_crop(a, IntArrayMat(3, 2), IntArrayMat(-2 + 1, -4 + 1), IntArrayMat(1, 3))
|| test_crop(a, IntArrayMat(4, 3), IntArrayMat(-3 + 1, -2 + 1), IntArrayMat(1, 3))
|| test_crop(a, IntArrayMat(5, 4), IntArrayMat(-4 + 1, -3 + 1), IntArrayMat(-3, -1))

|| test_crop(a, IntArrayMat(2, 3), IntArrayMat(-4 + 1, -6 + 1), IntArrayMat(2, 3))
|| test_crop(a, IntArrayMat(1, 2), IntArrayMat(-5 + 1, -5 + 1), IntArrayMat(2, 3))
|| test_crop(a, IntArrayMat(3, 1), IntArrayMat(-6 + 1, -4 + 1), IntArrayMat(-2, -1))

|| test_crop(a, IntArrayMat(11, 3, 3), IntArrayMat(-7 + 1, -3 + 1, -4 + 1), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(12, 4, 4), IntArrayMat(-12 + 1, -4 + 1, -3 + 1), IntArrayMat(0, 1, 2))
|| test_crop(a, IntArrayMat(8, 5, 5), IntArrayMat(-16 + 1, -5 + 1, -5 + 1), IntArrayMat(-4, -3, -2))

|| test_crop(a, IntArrayMat(11, 2, 2), IntArrayMat(-7 + 1, -5 + 1, -4 + 1), IntArrayMat(0, 1, 3))
|| test_crop(a, IntArrayMat(12, 1, 1), IntArrayMat(-12 + 1, -6 + 1, -5 + 1), IntArrayMat(0, 1, 3))
|| test_crop(a, IntArrayMat(8, 3, 3), IntArrayMat(-16 + 1, -4 + 1, -6 + 1), IntArrayMat(-4, -3, -1))

|| test_crop(a, IntArrayMat(11, 2, 5), IntArrayMat(-7 + 1, -2 + 1, -5 + 1), IntArrayMat(0, 2, 3))
|| test_crop(a, IntArrayMat(12, 3, 3), IntArrayMat(-12 + 1, -3 + 1, -4 + 1), IntArrayMat(0, 2, 3))
|| test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(-16 + 1, -4 + 1, -3 + 1), IntArrayMat(-4, -2, -1))

|| test_crop(a, IntArrayMat(1, 3, 3), IntArrayMat(-3 + 1, -6 + 1, -4 + 1), IntArrayMat(1, 2, 3))
|| test_crop(a, IntArrayMat(2, 2, 2), IntArrayMat(-4 + 1, -4 + 1, -5 + 1), IntArrayMat(1, 2, 3))
|| test_crop(a, IntArrayMat(3, 1, 1), IntArrayMat(-5 + 1, -5 + 1, -6 + 1), IntArrayMat(-3, -2, -1))
std::vector<int> params[][3] = {
{IntArray(11), IntArray(-233), IntArray(0)},
{IntArray(8), IntArray(-233), IntArray(0)},
{IntArray(6), IntArray(-233), IntArray(1)},
{IntArray(5), IntArray(-233), IntArray(2)},
{IntArray(4), IntArray(-233), IntArray(-2)},
{IntArray(6), IntArray(-233), IntArray(3)},
{IntArray(5), IntArray(-233), IntArray(-1)},
{IntArray(8, 4), IntArray(-233, -233), IntArray(0, 1)},
{IntArray(12, 6), IntArray(-233, -233), IntArray(0, 2)},
{IntArray(11, 5), IntArray(-233, -233), IntArray(-4, -2)},
{IntArray(4, 4), IntArray(-233, -233), IntArray(1, 2)},
{IntArray(12, 6), IntArray(-233, -233), IntArray(0, 3)},
{IntArray(5, 5), IntArray(-233, -233), IntArray(1, 3)},
{IntArray(4, 4), IntArray(-233, -233), IntArray(2, 3)},
{IntArray(12, 6, 6), IntArray(-233, -233, -233), IntArray(0, 1, 2)},
{IntArray(11, 5, 5), IntArray(-233, -233, -233), IntArray(0, 1, 2)},
{IntArray(8, 4, 4), IntArray(-233, -233, -233), IntArray(0, 1, 3)},
{IntArray(12, 6, 6), IntArray(-233, -233, -233), IntArray(0, 2, 3)},
{IntArray(11, 5, 5), IntArray(-233, -233, -233), IntArray(0, 2, 3)},
{IntArray(4, 4, 4), IntArray(-233, -233, -233), IntArray(1, 2, 3)},
{IntArray(6, 6, 6), IntArray(-233, -233, -233), IntArray(1, 2, 3)},
{IntArray(11, 5, 5, 5), IntArray(-233, -233, -233, -233), IntArray(0, 1, 2, 3)},
{IntArray(8, 4, 4, 4), IntArray(-233, -233, -233, -233), IntArray(0, 1, 2, 3)},
{IntArray(12, 6, 6, 6), IntArray(-233, -233, -233, -233), IntArray(-4, -3, -2, -1)},
{IntArray(11), IntArray(11 + 16), IntArray(0)},
{IntArray(12), IntArray(12 + 7), IntArray(0)},
{IntArray(8), IntArray(8 + 12), IntArray(-4)},
{IntArray(5), IntArray(11), IntArray(1)},
{IntArray(6), IntArray(13), IntArray(1)},
{IntArray(4), IntArray(12), IntArray(-3)},
{IntArray(3), IntArray(12), IntArray(2)},
{IntArray(4), IntArray(13), IntArray(2)},
{IntArray(5), IntArray(11), IntArray(-2)},
{IntArray(1), IntArray(8), IntArray(3)},
{IntArray(2), IntArray(7), IntArray(3)},
{IntArray(3), IntArray(6), IntArray(-1)},
{IntArray(11, 5), IntArray(11 + 7, 11), IntArray(0, 1)},
{IntArray(12, 6), IntArray(12 + 12, 12), IntArray(0, 1)},
{IntArray(8, 4), IntArray(8 + 16, 13), IntArray(-4, -3)},
{IntArray(11, 4), IntArray(11 + 12, 13), IntArray(0, 2)},
{IntArray(12, 3), IntArray(12 + 16, 11), IntArray(0, 2)},
{IntArray(8, 2), IntArray(8 + 7, 12), IntArray(-4, -2)},
{IntArray(11, 1), IntArray(11 + 16, 5), IntArray(0, 3)},
{IntArray(12, 2), IntArray(12 + 7, 6), IntArray(0, 3)},
{IntArray(8, 3), IntArray(8 + 12, 7), IntArray(-4, -1)},
{IntArray(3, 3), IntArray(13, 4), IntArray(1, 2)},
{IntArray(4, 2), IntArray(12, 3), IntArray(1, 2)},
{IntArray(5, 1), IntArray(11, 2), IntArray(-3, -2)},
{IntArray(5, 5), IntArray(11, 8), IntArray(1, 3)},
{IntArray(4, 6), IntArray(12, 9), IntArray(1, 3)},
{IntArray(3, 4), IntArray(13, 7), IntArray(-3, -1)},
{IntArray(2, 3), IntArray(12, 9), IntArray(2, 3)},
{IntArray(3, 2), IntArray(11, 7), IntArray(2, 3)},
{IntArray(4, 1), IntArray(10, 8), IntArray(-2, -1)},
{IntArray(11, 2, 2), IntArray(11 + 6, 9, 9), IntArray(0, 1, 2)},
{IntArray(12, 3, 3), IntArray(12 + 1, 10, 10), IntArray(0, 1, 2)},
{IntArray(8, 4, 4), IntArray(8 + 3, 11, 11), IntArray(-4, -3, -2)},
{IntArray(11, 4, 4), IntArray(11 + 12, 12, 12), IntArray(0, 1, 3)},
{IntArray(12, 5, 5), IntArray(12 + 8, 11, 11), IntArray(0, 1, 3)},
{IntArray(8, 6, 6), IntArray(8 + 4, 13, 13), IntArray(-4, -3, -1)},
{IntArray(11, 1, 4), IntArray(11 + 5, 12, 12), IntArray(0, 2, 3)},
{IntArray(12, 3, 3), IntArray(12 + 3, 11, 11), IntArray(0, 2, 3)},
{IntArray(8, 2, 5), IntArray(8 + 2, 10, 10), IntArray(-4, -2, -1)},
{IntArray(1, 1, 1), IntArray(7, 7, 7), IntArray(1, 2, 3)},
{IntArray(2, 2, 2), IntArray(8, 9, 10), IntArray(1, 2, 3)},
{IntArray(3, 3, 3), IntArray(11, 12, 13), IntArray(-3, -2, -1)},
{IntArray(11, 2, 3, 6), IntArray(11 + 11, 10, 12, 11), IntArray(0, 1, 2, 3)},
{IntArray(12, 3, 4, 5), IntArray(12 + 12, 9, 11, 13), IntArray(0, 1, 2, 3)},
{IntArray(8, 4, 5, 4), IntArray(8 + 8, 8, 10, 12), IntArray(-4, -3, -2, -1)},
{IntArray(11), IntArray(-7 + 1), IntArray(0)},
{IntArray(12), IntArray(-12 + 1), IntArray(0)},
{IntArray(8), IntArray(-16 + 1), IntArray(-4)},
{IntArray(5), IntArray(-6 + 1), IntArray(1)},
{IntArray(6), IntArray(-5 + 1), IntArray(1)},
{IntArray(4), IntArray(-4 + 1), IntArray(-3)},
{IntArray(4), IntArray(-4 + 1), IntArray(2)},
{IntArray(5), IntArray(-5 + 1), IntArray(2)},
{IntArray(6), IntArray(-6 + 1), IntArray(-2)},
{IntArray(1), IntArray(-5 + 1), IntArray(3)},
{IntArray(2), IntArray(-4 + 1), IntArray(3)},
{IntArray(3), IntArray(-3 + 1), IntArray(-1)},
{IntArray(11, 3), IntArray(-7 + 1, -3 + 1), IntArray(0, 1)},
{IntArray(12, 4), IntArray(-12 + 1, -4 + 1), IntArray(0, 1)},
{IntArray(8, 5), IntArray(-16 + 1, -5 + 1), IntArray(-4, -3)},
{IntArray(11, 1), IntArray(-12 + 1, -5 + 1), IntArray(0, 2)},
{IntArray(12, 2), IntArray(-16 + 1, -4 + 1), IntArray(0, 2)},
{IntArray(8, 3), IntArray(-7 + 1, -6 + 1), IntArray(-4, -2)},
{IntArray(11, 3), IntArray(-12 + 1, -2 + 1), IntArray(0, 3)},
{IntArray(12, 4), IntArray(-16 + 1, -3 + 1), IntArray(0, 3)},
{IntArray(8, 5), IntArray(-7 + 1, -4 + 1), IntArray(-4, -1)},
{IntArray(2, 3), IntArray(-4 + 1, -2 + 1), IntArray(1, 2)},
{IntArray(3, 4), IntArray(-2 + 1, -3 + 1), IntArray(1, 2)},
{IntArray(4, 5), IntArray(-3 + 1, -4 + 1), IntArray(-3, -2)},
{IntArray(3, 2), IntArray(-2 + 1, -4 + 1), IntArray(1, 3)},
{IntArray(4, 3), IntArray(-3 + 1, -2 + 1), IntArray(1, 3)},
{IntArray(5, 4), IntArray(-4 + 1, -3 + 1), IntArray(-3, -1)},
{IntArray(2, 3), IntArray(-4 + 1, -6 + 1), IntArray(2, 3)},
{IntArray(1, 2), IntArray(-5 + 1, -5 + 1), IntArray(2, 3)},
{IntArray(3, 1), IntArray(-6 + 1, -4 + 1), IntArray(-2, -1)},
{IntArray(11, 3, 3), IntArray(-7 + 1, -3 + 1, -4 + 1), IntArray(0, 1, 2)},
{IntArray(12, 4, 4), IntArray(-12 + 1, -4 + 1, -3 + 1), IntArray(0, 1, 2)},
{IntArray(8, 5, 5), IntArray(-16 + 1, -5 + 1, -5 + 1), IntArray(-4, -3, -2)},
{IntArray(11, 2, 2), IntArray(-7 + 1, -5 + 1, -4 + 1), IntArray(0, 1, 3)},
{IntArray(12, 1, 1), IntArray(-12 + 1, -6 + 1, -5 + 1), IntArray(0, 1, 3)},
{IntArray(8, 3, 3), IntArray(-16 + 1, -4 + 1, -6 + 1), IntArray(-4, -3, -1)},
{IntArray(11, 2, 5), IntArray(-7 + 1, -2 + 1, -5 + 1), IntArray(0, 2, 3)},
{IntArray(12, 3, 3), IntArray(-12 + 1, -3 + 1, -4 + 1), IntArray(0, 2, 3)},
{IntArray(8, 4, 4), IntArray(-16 + 1, -4 + 1, -3 + 1), IntArray(-4, -2, -1)},
{IntArray(1, 3, 3), IntArray(-3 + 1, -6 + 1, -4 + 1), IntArray(1, 2, 3)},
{IntArray(2, 2, 2), IntArray(-4 + 1, -4 + 1, -5 + 1), IntArray(1, 2, 3)},
{IntArray(3, 1, 1), IntArray(-5 + 1, -5 + 1, -6 + 1), IntArray(-3, -2, -1)},
{IntArray(11, 3, 4, 4), IntArray(-7 + 1, -3 + 1, -2 + 1, -4 + 1), IntArray(0, 1, 2, 3)},
{IntArray(12, 4, 5, 3), IntArray(-12 + 1, -4 + 1, -3 + 1, -5 + 1), IntArray(0, 1, 2, 3)},
{IntArray(8, 5, 6, 2), IntArray(-16 + 1, -5 + 1, -4 + 1, -3 + 1), IntArray(-4, -3, -2, -1)}
};

for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++)
{
int ret = test_crop(a, params[i][0], params[i][1], params[i][2]);
if (ret)
return ret;
}

|| test_crop(a, IntArrayMat(11, 3, 4, 4), IntArrayMat(-7 + 1, -3 + 1, -2 + 1, -4 + 1), IntArrayMat(0, 1, 2, 3))
|| test_crop(a, IntArrayMat(12, 4, 5, 3), IntArrayMat(-12 + 1, -4 + 1, -3 + 1, -5 + 1), IntArrayMat(0, 1, 2, 3))
|| test_crop(a, IntArrayMat(8, 5, 6, 2), IntArrayMat(-16 + 1, -5 + 1, -4 + 1, -3 + 1), IntArrayMat(-4, -3, -2, -1));
return 0;
}

int main()
@@ -361,16 +372,16 @@ int main()
SRAND(776757);

return 0
|| test_crop_1(RandomMat(112))
|| test_crop_1(RandomMat(126))
|| test_crop_1(RandomMat(127))
|| test_crop_4(RandomMat(20, 48))
|| test_crop_4(RandomMat(15, 36))
|| test_crop_4(RandomMat(16, 33))
|| test_crop_7(RandomMat(20, 20, 48))
|| test_crop_7(RandomMat(15, 15, 36))
|| test_crop_7(RandomMat(16, 16, 33))
|| test_crop_10(RandomMat(20, 20, 20, 48))
|| test_crop_10(RandomMat(15, 15, 15, 36))
|| test_crop_10(RandomMat(16, 16, 16, 33));
|| test_crop_1d(RandomMat(112))
|| test_crop_1d(RandomMat(126))
|| test_crop_1d(RandomMat(127))
|| test_crop_2d(RandomMat(20, 48))
|| test_crop_2d(RandomMat(15, 36))
|| test_crop_2d(RandomMat(16, 33))
|| test_crop_3d(RandomMat(20, 20, 48))
|| test_crop_3d(RandomMat(15, 15, 36))
|| test_crop_3d(RandomMat(16, 16, 33))
|| test_crop_4d(RandomMat(20, 20, 20, 48))
|| test_crop_4d(RandomMat(15, 15, 15, 36))
|| test_crop_4d(RandomMat(16, 16, 16, 33));
}

+ 47
- 44
tests/test_expanddims.cpp View File

@@ -33,58 +33,61 @@ static int test_expanddims(const ncnn::Mat& a, int expand_w, int expand_h, int e
return ret;
}

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_expanddims_axes(const ncnn::Mat& a, const ncnn::Mat& axes)
static int test_expanddims_axes(const ncnn::Mat& a, const std::vector<int>& axes_array)
{
ncnn::Mat axes(axes_array.size());
{
int* p = axes;
for (size_t i = 0; i < axes_array.size(); i++)
{
p[i] = axes_array[i];
}
}

ncnn::ParamDict pd;
pd.set(3, axes);

@@ -95,7 +98,7 @@ static int test_expanddims_axes(const ncnn::Mat& a, const ncnn::Mat& axes)
{
fprintf(stderr, "test_expanddims_axes failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " axes=");
print_int_array(axes);
print_int_array(axes_array);
fprintf(stderr, "\n");
}

@@ -122,21 +125,21 @@ static int test_expanddims_all_params(const ncnn::Mat& a)
|| test_expanddims(a, 1, 1, 1, 0)
|| test_expanddims(a, 1, 1, 1, 1)

|| test_expanddims_axes(a, IntArrayMat(0))
|| test_expanddims_axes(a, IntArrayMat(1))
|| test_expanddims_axes(a, IntArrayMat(2))
|| test_expanddims_axes(a, IntArrayMat(3))
|| test_expanddims_axes(a, IntArrayMat(0, 1))
|| test_expanddims_axes(a, IntArrayMat(0, 2))
|| test_expanddims_axes(a, IntArrayMat(0, 3))
|| test_expanddims_axes(a, IntArrayMat(1, 2))
|| test_expanddims_axes(a, IntArrayMat(1, 3))
|| test_expanddims_axes(a, IntArrayMat(2, 3))
|| test_expanddims_axes(a, IntArrayMat(0, 1, 2))
|| test_expanddims_axes(a, IntArrayMat(0, 1, 3))
|| test_expanddims_axes(a, IntArrayMat(0, 2, 3))
|| test_expanddims_axes(a, IntArrayMat(1, 2, 3))
|| test_expanddims_axes(a, IntArrayMat(0, 1, 2, 3));
|| test_expanddims_axes(a, IntArray(0))
|| test_expanddims_axes(a, IntArray(1))
|| test_expanddims_axes(a, IntArray(2))
|| test_expanddims_axes(a, IntArray(3))
|| test_expanddims_axes(a, IntArray(0, 1))
|| test_expanddims_axes(a, IntArray(0, 2))
|| test_expanddims_axes(a, IntArray(0, 3))
|| test_expanddims_axes(a, IntArray(1, 2))
|| test_expanddims_axes(a, IntArray(1, 3))
|| test_expanddims_axes(a, IntArray(2, 3))
|| test_expanddims_axes(a, IntArray(0, 1, 2))
|| test_expanddims_axes(a, IntArray(0, 1, 3))
|| test_expanddims_axes(a, IntArray(0, 2, 3))
|| test_expanddims_axes(a, IntArray(1, 2, 3))
|| test_expanddims_axes(a, IntArray(0, 1, 2, 3));
}

static int test_expanddims_0()


+ 116
- 245
tests/test_reduction.cpp View File

@@ -18,52 +18,46 @@

static int op_type = 0;

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}
@@ -94,7 +88,7 @@ static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims)
return ret;
}

static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const ncnn::Mat& axes)
static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const std::vector<int>& axes_array)
{
ncnn::Mat a = _a;
if (op_type == 9 || op_type == 10)
@@ -103,6 +97,15 @@ static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const
Randomize(a, 0.001f, 2.f);
}

ncnn::Mat axes(axes_array.size());
{
int* p = axes;
for (size_t i = 0; i < axes_array.size(); i++)
{
p[i] = axes_array[i];
}
}

ncnn::ParamDict pd;
pd.set(0, op_type);
pd.set(1, 0); // reduce_all
@@ -118,247 +121,115 @@ static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const
{
fprintf(stderr, "test_reduction failed a.dims=%d a=(%d %d %d %d) op_type=%d coeff=%f keepdims=%d", a.dims, a.w, a.h, a.d, a.c, op_type, coeff, keepdims);
fprintf(stderr, " axes=");
print_int_array(axes);
print_int_array(axes_array);
fprintf(stderr, "\n");
}

return ret;
}

static int test_reduction_nd(const ncnn::Mat& a)
{
int ret1 = 0
|| test_reduction(a, 1.f, 0)
|| test_reduction(a, 2.f, 0)
|| test_reduction(a, 1.f, 1)
|| test_reduction(a, 2.f, 1)
|| test_reduction(a, 1.f, 0, IntArray(0))
|| test_reduction(a, 1.f, 1, IntArray(0));

if (a.dims == 1 || ret1 != 0)
return ret1;

int ret2 = 0
|| test_reduction(a, 2.f, 0, IntArray(1))
|| test_reduction(a, 2.f, 1, IntArray(1))
|| test_reduction(a, 1.f, 0, IntArray(0, 1))
|| test_reduction(a, 1.f, 1, IntArray(0, 1));

if (a.dims == 2 || ret2 != 0)
return ret2;

int ret3 = 0
|| test_reduction(a, 1.f, 0, IntArray(2))
|| test_reduction(a, 1.f, 1, IntArray(2))
|| test_reduction(a, 2.f, 0, IntArray(0, 2))
|| test_reduction(a, 2.f, 0, IntArray(1, 2))
|| test_reduction(a, 2.f, 1, IntArray(0, 2))
|| test_reduction(a, 2.f, 1, IntArray(1, 2))
|| test_reduction(a, 1.f, 0, IntArray(0, 1, 2))
|| test_reduction(a, 1.f, 1, IntArray(0, 1, 2));

if (a.dims == 3 || ret3 != 0)
return ret3;

int ret4 = 0
|| test_reduction(a, 2.f, 0, IntArray(3))
|| test_reduction(a, 2.f, 1, IntArray(3))
|| test_reduction(a, 1.f, 0, IntArray(0, 3))
|| test_reduction(a, 1.f, 0, IntArray(1, 3))
|| test_reduction(a, 2.f, 0, IntArray(2, 3))
|| test_reduction(a, 1.f, 1, IntArray(0, 3))
|| test_reduction(a, 1.f, 1, IntArray(1, 3))
|| test_reduction(a, 2.f, 1, IntArray(2, 3))
|| test_reduction(a, 2.f, 0, IntArray(0, 1, 3))
|| test_reduction(a, 1.f, 0, IntArray(0, 2, 3))
|| test_reduction(a, 2.f, 0, IntArray(1, 2, 3))
|| test_reduction(a, 2.f, 1, IntArray(0, 1, 3))
|| test_reduction(a, 1.f, 1, IntArray(0, 2, 3))
|| test_reduction(a, 2.f, 1, IntArray(1, 2, 3))
|| test_reduction(a, 1.f, 0, IntArray(0, 1, 2, 3))
|| test_reduction(a, 1.f, 1, IntArray(0, 1, 2, 3));

return ret4;
}

static int test_reduction_0()
{
ncnn::Mat a = RandomMat(5, 6, 7, 24);
ncnn::Mat b = RandomMat(7, 8, 9, 12);
ncnn::Mat c = RandomMat(3, 4, 5, 13);

return 0
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0)
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0)
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0)
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0)
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0)
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0)

|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1)
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1)
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1)
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1)
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1)
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1)

|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(2))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(0, 2))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(1, 2))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(1, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(2, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(0, 1, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 2, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(1, 2, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 1, 2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(2))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(0, 2))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(1, 2))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(1, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(0, 1, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(1, 2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 1, 2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(2))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(0, 2))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(1, 2))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(1, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(0, 1, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(1, 2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 1, 2, 3))

|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(2))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(0, 2))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(1, 2))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(1, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(2, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(0, 1, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 2, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(1, 2, 3))
|| test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 1, 2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(2))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(0, 2))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(1, 2))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(1, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(0, 1, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(1, 2, 3))
|| test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 1, 2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(2))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(0, 2))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(1, 2))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(1, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(0, 1, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(1, 2, 3))
|| test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 1, 2, 3));
|| test_reduction_nd(a)
|| test_reduction_nd(b)
|| test_reduction_nd(c);
}

static int test_reduction_1()
{
ncnn::Mat a = RandomMat(5, 7, 24);
ncnn::Mat b = RandomMat(7, 9, 12);
ncnn::Mat c = RandomMat(3, 5, 13);

return 0
|| test_reduction(RandomMat(5, 7, 24), 1.f, 0)
|| test_reduction(RandomMat(5, 7, 24), 2.f, 0)
|| test_reduction(RandomMat(7, 9, 12), 1.f, 0)
|| test_reduction(RandomMat(7, 9, 12), 2.f, 0)
|| test_reduction(RandomMat(3, 5, 13), 1.f, 0)
|| test_reduction(RandomMat(3, 5, 13), 2.f, 0)

|| test_reduction(RandomMat(5, 7, 24), 1.f, 1)
|| test_reduction(RandomMat(5, 7, 24), 2.f, 1)
|| test_reduction(RandomMat(7, 9, 12), 1.f, 1)
|| test_reduction(RandomMat(7, 9, 12), 2.f, 1)
|| test_reduction(RandomMat(3, 5, 13), 1.f, 1)
|| test_reduction(RandomMat(3, 5, 13), 2.f, 1)

|| test_reduction(RandomMat(5, 7, 24), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(5, 7, 24), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(5, 7, 24), 1.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(5, 7, 24), 2.f, 0, IntArrayMat(0, 2))
|| test_reduction(RandomMat(5, 7, 24), 1.f, 0, IntArrayMat(1, 2))
|| test_reduction(RandomMat(5, 7, 24), 2.f, 0, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(7, 9, 12), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(7, 9, 12), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(7, 9, 12), 1.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(7, 9, 12), 2.f, 0, IntArrayMat(0, 2))
|| test_reduction(RandomMat(7, 9, 12), 1.f, 0, IntArrayMat(1, 2))
|| test_reduction(RandomMat(7, 9, 12), 2.f, 0, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(3, 5, 13), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(3, 5, 13), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(3, 5, 13), 1.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(3, 5, 13), 2.f, 0, IntArrayMat(0, 2))
|| test_reduction(RandomMat(3, 5, 13), 1.f, 0, IntArrayMat(1, 2))
|| test_reduction(RandomMat(3, 5, 13), 2.f, 0, IntArrayMat(0, 1, 2))

|| test_reduction(RandomMat(5, 7, 24), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(5, 7, 24), 2.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(5, 7, 24), 1.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(5, 7, 24), 2.f, 1, IntArrayMat(0, 2))
|| test_reduction(RandomMat(5, 7, 24), 1.f, 1, IntArrayMat(1, 2))
|| test_reduction(RandomMat(5, 7, 24), 2.f, 1, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(7, 9, 12), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(7, 9, 12), 2.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(7, 9, 12), 1.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(7, 9, 12), 2.f, 1, IntArrayMat(0, 2))
|| test_reduction(RandomMat(7, 9, 12), 1.f, 1, IntArrayMat(1, 2))
|| test_reduction(RandomMat(7, 9, 12), 2.f, 1, IntArrayMat(0, 1, 2))
|| test_reduction(RandomMat(3, 5, 13), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(3, 5, 13), 2.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(3, 5, 13), 1.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(3, 5, 13), 2.f, 1, IntArrayMat(0, 2))
|| test_reduction(RandomMat(3, 5, 13), 1.f, 1, IntArrayMat(1, 2))
|| test_reduction(RandomMat(3, 5, 13), 2.f, 1, IntArrayMat(0, 1, 2));
|| test_reduction_nd(a)
|| test_reduction_nd(b)
|| test_reduction_nd(c);
}

static int test_reduction_2()
{
ncnn::Mat a = RandomMat(15, 24);
ncnn::Mat b = RandomMat(17, 12);
ncnn::Mat c = RandomMat(19, 15);

return 0
|| test_reduction(RandomMat(15, 24), 1.f, 0)
|| test_reduction(RandomMat(15, 24), 2.f, 0)
|| test_reduction(RandomMat(17, 12), 1.f, 0)
|| test_reduction(RandomMat(17, 12), 2.f, 0)
|| test_reduction(RandomMat(19, 15), 1.f, 0)
|| test_reduction(RandomMat(19, 15), 2.f, 0)

|| test_reduction(RandomMat(15, 24), 1.f, 1)
|| test_reduction(RandomMat(15, 24), 2.f, 1)
|| test_reduction(RandomMat(17, 12), 1.f, 1)
|| test_reduction(RandomMat(17, 12), 2.f, 1)
|| test_reduction(RandomMat(19, 15), 1.f, 1)
|| test_reduction(RandomMat(19, 15), 2.f, 1)

|| test_reduction(RandomMat(15, 24), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(15, 24), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(15, 24), 1.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(17, 12), 2.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(17, 12), 1.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(17, 12), 2.f, 0, IntArrayMat(0, 1))
|| test_reduction(RandomMat(19, 15), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(19, 15), 2.f, 0, IntArrayMat(1))
|| test_reduction(RandomMat(19, 15), 1.f, 0, IntArrayMat(0, 1))

|| test_reduction(RandomMat(15, 24), 2.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(15, 24), 1.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(15, 24), 2.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(17, 12), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(17, 12), 2.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(17, 12), 1.f, 1, IntArrayMat(0, 1))
|| test_reduction(RandomMat(19, 15), 2.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(19, 15), 1.f, 1, IntArrayMat(1))
|| test_reduction(RandomMat(19, 15), 2.f, 1, IntArrayMat(0, 1));
|| test_reduction_nd(a)
|| test_reduction_nd(b)
|| test_reduction_nd(c);
}

static int test_reduction_3()
{
ncnn::Mat a = RandomMat(128);
ncnn::Mat b = RandomMat(124);
ncnn::Mat c = RandomMat(127);

return 0
|| test_reduction(RandomMat(128), 1.f, 0)
|| test_reduction(RandomMat(128), 2.f, 0)
|| test_reduction(RandomMat(124), 1.f, 0)
|| test_reduction(RandomMat(124), 2.f, 0)
|| test_reduction(RandomMat(127), 1.f, 0)
|| test_reduction(RandomMat(127), 2.f, 0)

|| test_reduction(RandomMat(128), 1.f, 1)
|| test_reduction(RandomMat(128), 2.f, 1)
|| test_reduction(RandomMat(124), 1.f, 1)
|| test_reduction(RandomMat(124), 2.f, 1)
|| test_reduction(RandomMat(127), 1.f, 1)
|| test_reduction(RandomMat(127), 2.f, 1)

|| test_reduction(RandomMat(128), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(128), 2.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(124), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(124), 2.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(127), 1.f, 0, IntArrayMat(0))
|| test_reduction(RandomMat(127), 2.f, 0, IntArrayMat(0))

|| test_reduction(RandomMat(128), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(128), 2.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(124), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(124), 2.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(127), 1.f, 1, IntArrayMat(0))
|| test_reduction(RandomMat(127), 1.f, 1, IntArrayMat(0));
|| test_reduction_nd(a)
|| test_reduction_nd(b)
|| test_reduction_nd(c);
}

int main()


+ 81
- 69
tests/test_slice.cpp View File

@@ -14,58 +14,61 @@

#include "testutil.h"

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_slice(const ncnn::Mat& a, const ncnn::Mat& slices, int axis)
static int test_slice(const ncnn::Mat& a, const std::vector<int>& slices_array, int axis)
{
ncnn::Mat slices(slices_array.size());
{
int* p = slices;
for (size_t i = 0; i < slices_array.size(); i++)
{
p[i] = slices_array[i];
}
}

ncnn::ParamDict pd;
pd.set(0, slices);
pd.set(1, axis);
@@ -80,15 +83,24 @@ static int test_slice(const ncnn::Mat& a, const ncnn::Mat& slices, int axis)
{
fprintf(stderr, "test_slice failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " slices=");
print_int_array(slices);
print_int_array(slices_array);
fprintf(stderr, " axis=%d\n", axis);
}

return ret;
}

static int test_slice_indices(const ncnn::Mat& a, const ncnn::Mat& indices, int axis)
static int test_slice_indices(const ncnn::Mat& a, const std::vector<int>& indices_array, int axis)
{
ncnn::Mat indices(indices_array.size());
{
int* p = indices;
for (size_t i = 0; i < indices_array.size(); i++)
{
p[i] = indices_array[i];
}
}

ncnn::ParamDict pd;
pd.set(1, axis);
pd.set(2, indices);
@@ -103,7 +115,7 @@ static int test_slice_indices(const ncnn::Mat& a, const ncnn::Mat& indices, int
{
fprintf(stderr, "test_slice_indices failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " indices=");
print_int_array(indices);
print_int_array(indices_array);
fprintf(stderr, " axis=%d\n", axis);
}

@@ -121,20 +133,20 @@ static int test_slice_0()
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
int ret = 0
|| test_slice(a[i], IntArrayMat(-233, -233, -233), 0)
|| test_slice(a[i], IntArrayMat(-233, -233, -233), 1)
|| test_slice(a[i], IntArrayMat(-233, -233, -233), -2)
|| test_slice(a[i], IntArrayMat(-233, -233, -233), 3)
|| test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(32, 8, -233), 0)
|| test_slice(a[i], IntArrayMat(2, 12, 16, -233), 1)
|| test_slice(a[i], IntArrayMat(16, 4, 5, -233), -2)
|| test_slice(a[i], IntArrayMat(8, 2, 16, -233), 3)
|| test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0)
|| test_slice_indices(a[i], IntArrayMat(4, 20, 4), 1)
|| test_slice_indices(a[i], IntArrayMat(16, -16), -2)
|| test_slice_indices(a[i], IntArrayMat(1, -12), 3);
|| test_slice(a[i], IntArray(-233, -233, -233), 0)
|| test_slice(a[i], IntArray(-233, -233, -233), 1)
|| test_slice(a[i], IntArray(-233, -233, -233), -2)
|| test_slice(a[i], IntArray(-233, -233, -233), 3)
|| test_slice(a[i], IntArray(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArray(12, 16, -233), 0)
|| test_slice(a[i], IntArray(32, 8, -233), 0)
|| test_slice(a[i], IntArray(2, 12, 16, -233), 1)
|| test_slice(a[i], IntArray(16, 4, 5, -233), -2)
|| test_slice(a[i], IntArray(8, 2, 16, -233), 3)
|| test_slice_indices(a[i], IntArray(2, -24, -8), 0)
|| test_slice_indices(a[i], IntArray(4, 20, 4), 1)
|| test_slice_indices(a[i], IntArray(16, -16), -2)
|| test_slice_indices(a[i], IntArray(1, -12), 3);

if (ret != 0)
return ret;
@@ -154,17 +166,17 @@ static int test_slice_1()
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
int ret = 0
|| test_slice(a[i], IntArrayMat(-233, -233, -233), 0)
|| test_slice(a[i], IntArrayMat(-233, -233, -233), 1)
|| test_slice(a[i], IntArrayMat(-233, -233, -233), -1)
|| test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(32, 8, -233), 0)
|| test_slice(a[i], IntArrayMat(2, 12, 16, -233), 1)
|| test_slice(a[i], IntArrayMat(16, 4, 5, -233), -1)
|| test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0)
|| test_slice_indices(a[i], IntArrayMat(4, 20, 4), 1)
|| test_slice_indices(a[i], IntArrayMat(1, -12), 2);
|| test_slice(a[i], IntArray(-233, -233, -233), 0)
|| test_slice(a[i], IntArray(-233, -233, -233), 1)
|| test_slice(a[i], IntArray(-233, -233, -233), -1)
|| test_slice(a[i], IntArray(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArray(12, 16, -233), 0)
|| test_slice(a[i], IntArray(32, 8, -233), 0)
|| test_slice(a[i], IntArray(2, 12, 16, -233), 1)
|| test_slice(a[i], IntArray(16, 4, 5, -233), -1)
|| test_slice_indices(a[i], IntArray(2, -24, -8), 0)
|| test_slice_indices(a[i], IntArray(4, 20, 4), 1)
|| test_slice_indices(a[i], IntArray(1, -12), 2);

if (ret != 0)
return ret;
@@ -184,14 +196,14 @@ static int test_slice_2()
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
int ret = 0
|| test_slice(a[i], IntArrayMat(-233, -233, -233), 0)
|| test_slice(a[i], IntArrayMat(-233, -233, -233), -1)
|| test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(32, 8, -233), -2)
|| test_slice(a[i], IntArrayMat(2, 12, 16, -233), -1)
|| test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0)
|| test_slice_indices(a[i], IntArrayMat(1, -12), 1);
|| test_slice(a[i], IntArray(-233, -233, -233), 0)
|| test_slice(a[i], IntArray(-233, -233, -233), -1)
|| test_slice(a[i], IntArray(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArray(12, 16, -233), 0)
|| test_slice(a[i], IntArray(32, 8, -233), -2)
|| test_slice(a[i], IntArray(2, 12, 16, -233), -1)
|| test_slice_indices(a[i], IntArray(2, -24, -8), 0)
|| test_slice_indices(a[i], IntArray(1, -12), 1);

if (ret != 0)
return ret;
@@ -211,11 +223,11 @@ static int test_slice_3()
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
int ret = 0
|| test_slice(a[i], IntArrayMat(-233, -233, -233), 0)
|| test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(12, 16, -233), 0)
|| test_slice(a[i], IntArrayMat(32, 8, -233), -1)
|| test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0);
|| test_slice(a[i], IntArray(-233, -233, -233), 0)
|| test_slice(a[i], IntArray(3, 12, 16, -233), 0)
|| test_slice(a[i], IntArray(12, 16, -233), 0)
|| test_slice(a[i], IntArray(32, 8, -233), -1)
|| test_slice_indices(a[i], IntArray(2, -24, -8), 0);

if (ret != 0)
return ret;


+ 57
- 45
tests/test_slice_oom.cpp View File

@@ -14,58 +14,61 @@

#include "testutil.h"

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_slice_oom(const ncnn::Mat& a, const ncnn::Mat& slices, int axis)
static int test_slice_oom(const ncnn::Mat& a, const std::vector<int>& slices_array, int axis)
{
ncnn::Mat slices(slices_array.size());
{
int* p = slices;
for (size_t i = 0; i < slices_array.size(); i++)
{
p[i] = slices_array[i];
}
}

ncnn::ParamDict pd;
pd.set(0, slices);
pd.set(1, axis);
@@ -80,15 +83,24 @@ static int test_slice_oom(const ncnn::Mat& a, const ncnn::Mat& slices, int axis)
{
fprintf(stderr, "test_slice_oom failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " slices=");
print_int_array(slices);
print_int_array(slices_array);
fprintf(stderr, " axis=%d\n", axis);
}

return ret;
}

static int test_slice_oom_indices(const ncnn::Mat& a, const ncnn::Mat& indices, int axis)
static int test_slice_oom_indices(const ncnn::Mat& a, const std::vector<int>& indices_array, int axis)
{
ncnn::Mat indices(indices_array.size());
{
int* p = indices;
for (size_t i = 0; i < indices_array.size(); i++)
{
p[i] = indices_array[i];
}
}

ncnn::ParamDict pd;
pd.set(1, axis);
pd.set(2, indices);
@@ -103,7 +115,7 @@ static int test_slice_oom_indices(const ncnn::Mat& a, const ncnn::Mat& indices,
{
fprintf(stderr, "test_slice_oom_indices failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " indices=");
print_int_array(indices);
print_int_array(indices_array);
fprintf(stderr, " axis=%d\n", axis);
}

@@ -115,11 +127,11 @@ static int test_slice_0()
ncnn::Mat a = RandomMat(48, 48, 48, 48);

return 0
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0)
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 1)
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 2)
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 3)
|| test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0);
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 0)
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 1)
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 2)
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 3)
|| test_slice_oom_indices(a, IntArray(2, -24, -8), 0);
}

static int test_slice_1()
@@ -127,10 +139,10 @@ static int test_slice_1()
ncnn::Mat a = RandomMat(48, 48, 48);

return 0
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0)
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 1)
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 2)
|| test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0);
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 0)
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 1)
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 2)
|| test_slice_oom_indices(a, IntArray(2, -24, -8), 0);
}

static int test_slice_2()
@@ -138,9 +150,9 @@ static int test_slice_2()
ncnn::Mat a = RandomMat(48, 48);

return 0
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0)
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 1)
|| test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0);
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 0)
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 1)
|| test_slice_oom_indices(a, IntArray(2, -24, -8), 0);
}

static int test_slice_3()
@@ -148,8 +160,8 @@ static int test_slice_3()
ncnn::Mat a = RandomMat(48);

return 0
|| test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0)
|| test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0);
|| test_slice_oom(a, IntArray(3, 12, 16, -233), 0)
|| test_slice_oom_indices(a, IntArray(2, -24, -8), 0);
}

int main()


+ 47
- 44
tests/test_squeeze.cpp View File

@@ -33,58 +33,61 @@ static int test_squeeze(const ncnn::Mat& a, int squeeze_w, int squeeze_h, int sq
return ret;
}

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_squeeze_axes(const ncnn::Mat& a, const ncnn::Mat& axes)
static int test_squeeze_axes(const ncnn::Mat& a, const std::vector<int>& axes_array)
{
ncnn::Mat axes(axes_array.size());
{
int* p = axes;
for (size_t i = 0; i < axes_array.size(); i++)
{
p[i] = axes_array[i];
}
}

ncnn::ParamDict pd;
pd.set(3, axes);

@@ -95,7 +98,7 @@ static int test_squeeze_axes(const ncnn::Mat& a, const ncnn::Mat& axes)
{
fprintf(stderr, "test_squeeze_axes failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c);
fprintf(stderr, " axes=");
print_int_array(axes);
print_int_array(axes_array);
fprintf(stderr, "\n");
}

@@ -122,21 +125,21 @@ static int test_squeeze_all_params(const ncnn::Mat& a)
|| test_squeeze(a, 1, 1, 1, 0)
|| test_squeeze(a, 1, 1, 1, 1)

|| test_squeeze_axes(a, IntArrayMat(0))
|| test_squeeze_axes(a, IntArrayMat(1))
|| test_squeeze_axes(a, IntArrayMat(2))
|| test_squeeze_axes(a, IntArrayMat(3))
|| test_squeeze_axes(a, IntArrayMat(0, 1))
|| test_squeeze_axes(a, IntArrayMat(0, 2))
|| test_squeeze_axes(a, IntArrayMat(0, 3))
|| test_squeeze_axes(a, IntArrayMat(1, 2))
|| test_squeeze_axes(a, IntArrayMat(1, 3))
|| test_squeeze_axes(a, IntArrayMat(2, 3))
|| test_squeeze_axes(a, IntArrayMat(0, 1, 2))
|| test_squeeze_axes(a, IntArrayMat(0, 1, 3))
|| test_squeeze_axes(a, IntArrayMat(0, 2, 3))
|| test_squeeze_axes(a, IntArrayMat(1, 2, 3))
|| test_squeeze_axes(a, IntArrayMat(0, 1, 2, 3));
|| test_squeeze_axes(a, IntArray(0))
|| test_squeeze_axes(a, IntArray(1))
|| test_squeeze_axes(a, IntArray(2))
|| test_squeeze_axes(a, IntArray(3))
|| test_squeeze_axes(a, IntArray(0, 1))
|| test_squeeze_axes(a, IntArray(0, 2))
|| test_squeeze_axes(a, IntArray(0, 3))
|| test_squeeze_axes(a, IntArray(1, 2))
|| test_squeeze_axes(a, IntArray(1, 3))
|| test_squeeze_axes(a, IntArray(2, 3))
|| test_squeeze_axes(a, IntArray(0, 1, 2))
|| test_squeeze_axes(a, IntArray(0, 1, 3))
|| test_squeeze_axes(a, IntArray(0, 2, 3))
|| test_squeeze_axes(a, IntArray(1, 2, 3))
|| test_squeeze_axes(a, IntArray(0, 1, 2, 3));
}

static int test_squeeze_0()


+ 82
- 79
tests/test_tile.cpp View File

@@ -31,58 +31,61 @@ static int test_tile(const ncnn::Mat& a, int axis, int tiles)
return ret;
}

static ncnn::Mat IntArrayMat(int a0)
static std::vector<int> IntArray(int a0)
{
ncnn::Mat m(1);
int* p = m;
p[0] = a0;
std::vector<int> m(1);
m[0] = a0;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1)
static std::vector<int> IntArray(int a0, int a1)
{
ncnn::Mat m(2);
int* p = m;
p[0] = a0;
p[1] = a1;
std::vector<int> m(2);
m[0] = a0;
m[1] = a1;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2)
static std::vector<int> IntArray(int a0, int a1, int a2)
{
ncnn::Mat m(3);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
std::vector<int> m(3);
m[0] = a0;
m[1] = a1;
m[2] = a2;
return m;
}

static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3)
static std::vector<int> IntArray(int a0, int a1, int a2, int a3)
{
ncnn::Mat m(4);
int* p = m;
p[0] = a0;
p[1] = a1;
p[2] = a2;
p[3] = a3;
std::vector<int> m(4);
m[0] = a0;
m[1] = a1;
m[2] = a2;
m[3] = a3;
return m;
}

static void print_int_array(const ncnn::Mat& a)
static void print_int_array(const std::vector<int>& a)
{
const int* pa = a;

fprintf(stderr, "[");
for (int i = 0; i < a.w; i++)
for (size_t i = 0; i < a.size(); i++)
{
fprintf(stderr, " %d", pa[i]);
fprintf(stderr, " %d", a[i]);
}
fprintf(stderr, " ]");
}

static int test_tile(const ncnn::Mat& a, const ncnn::Mat& repeats)
static int test_tile(const ncnn::Mat& a, const std::vector<int>& repeats_array)
{
ncnn::Mat repeats(repeats_array.size());
{
int* p = repeats;
for (size_t i = 0; i < repeats_array.size(); i++)
{
p[i] = repeats_array[i];
}
}

ncnn::ParamDict pd;
pd.set(2, repeats);

@@ -92,7 +95,7 @@ static int test_tile(const ncnn::Mat& a, const ncnn::Mat& repeats)
if (ret != 0)
{
fprintf(stderr, "test_tile failed a.dims=%d a=(%d %d %d %d) repeats=", a.dims, a.w, a.h, a.d, a.c);
print_int_array(repeats);
print_int_array(repeats_array);
fprintf(stderr, "\n");
}

@@ -119,18 +122,18 @@ static int test_tile_0()
|| test_tile(c, 2, 5)
|| test_tile(c, 3, 2)

|| test_tile(a, IntArrayMat(3))
|| test_tile(a, IntArrayMat(2, 4))
|| test_tile(a, IntArrayMat(2, 2, 5))
|| test_tile(a, IntArrayMat(3, 1, 3, 2))
|| test_tile(b, IntArrayMat(3, 1))
|| test_tile(b, IntArrayMat(4, 1, 4))
|| test_tile(b, IntArrayMat(2, 2, 2, 1))
|| test_tile(b, IntArrayMat(3, 2, 1))
|| test_tile(c, IntArrayMat(3))
|| test_tile(c, IntArrayMat(1, 1, 4))
|| test_tile(c, IntArrayMat(2, 2, 5))
|| test_tile(c, IntArrayMat(3, 2, 1, 9));
|| test_tile(a, IntArray(3))
|| test_tile(a, IntArray(2, 4))
|| test_tile(a, IntArray(2, 2, 5))
|| test_tile(a, IntArray(3, 1, 3, 2))
|| test_tile(b, IntArray(3, 1))
|| test_tile(b, IntArray(4, 1, 4))
|| test_tile(b, IntArray(2, 2, 2, 1))
|| test_tile(b, IntArray(3, 2, 1))
|| test_tile(c, IntArray(3))
|| test_tile(c, IntArray(1, 1, 4))
|| test_tile(c, IntArray(2, 2, 5))
|| test_tile(c, IntArray(3, 2, 1, 9));
}

static int test_tile_1()
@@ -150,18 +153,18 @@ static int test_tile_1()
|| test_tile(c, 1, 2)
|| test_tile(c, 2, 2)

|| test_tile(a, IntArrayMat(5))
|| test_tile(a, IntArrayMat(1, 4))
|| test_tile(a, IntArrayMat(2, 1, 4))
|| test_tile(a, IntArrayMat(1, 2, 1, 4))
|| test_tile(b, IntArrayMat(3))
|| test_tile(b, IntArrayMat(1, 3, 3))
|| test_tile(b, IntArrayMat(2, 3))
|| test_tile(b, IntArrayMat(2, 3, 3, 3))
|| test_tile(c, IntArrayMat(1))
|| test_tile(c, IntArrayMat(2, 1))
|| test_tile(c, IntArrayMat(2, 2, 2))
|| test_tile(c, IntArrayMat(2, 1, 2, 1));
|| test_tile(a, IntArray(5))
|| test_tile(a, IntArray(1, 4))
|| test_tile(a, IntArray(2, 1, 4))
|| test_tile(a, IntArray(1, 2, 1, 4))
|| test_tile(b, IntArray(3))
|| test_tile(b, IntArray(1, 3, 3))
|| test_tile(b, IntArray(2, 3))
|| test_tile(b, IntArray(2, 3, 3, 3))
|| test_tile(c, IntArray(1))
|| test_tile(c, IntArray(2, 1))
|| test_tile(c, IntArray(2, 2, 2))
|| test_tile(c, IntArray(2, 1, 2, 1));
}

static int test_tile_2()
@@ -178,18 +181,18 @@ static int test_tile_2()
|| test_tile(c, 0, 5)
|| test_tile(c, 1, 6)

|| test_tile(a, IntArrayMat(2))
|| test_tile(a, IntArrayMat(1, 1))
|| test_tile(a, IntArrayMat(4, 1, 1))
|| test_tile(a, IntArrayMat(2, 4, 4, 1))
|| test_tile(b, IntArrayMat(3))
|| test_tile(b, IntArrayMat(2, 4))
|| test_tile(b, IntArrayMat(2, 4, 3, 1))
|| test_tile(b, IntArrayMat(1, 2, 1, 4))
|| test_tile(c, IntArrayMat(5))
|| test_tile(c, IntArrayMat(6, 1))
|| test_tile(c, IntArrayMat(6, 1, 6))
|| test_tile(c, IntArrayMat(3, 2, 1, 1));
|| test_tile(a, IntArray(2))
|| test_tile(a, IntArray(1, 1))
|| test_tile(a, IntArray(4, 1, 1))
|| test_tile(a, IntArray(2, 4, 4, 1))
|| test_tile(b, IntArray(3))
|| test_tile(b, IntArray(2, 4))
|| test_tile(b, IntArray(2, 4, 3, 1))
|| test_tile(b, IntArray(1, 2, 1, 4))
|| test_tile(c, IntArray(5))
|| test_tile(c, IntArray(6, 1))
|| test_tile(c, IntArray(6, 1, 6))
|| test_tile(c, IntArray(3, 2, 1, 1));
}

static int test_tile_3()
@@ -204,20 +207,20 @@ static int test_tile_3()
|| test_tile(b, 0, 3)
|| test_tile(c, 0, 4)

|| test_tile(a, IntArrayMat(10))
|| test_tile(a, IntArrayMat(10, 1))
|| test_tile(a, IntArrayMat(5, 2, 1))
|| test_tile(a, IntArrayMat(2, 2, 2, 3))
|| test_tile(b, IntArrayMat(2))
|| test_tile(b, IntArrayMat(2, 2))
|| test_tile(b, IntArrayMat(2, 2, 1))
|| test_tile(b, IntArrayMat(4, 1, 2, 2))
|| test_tile(c, IntArrayMat(3))
|| test_tile(c, IntArrayMat(4, 3))
|| test_tile(c, IntArrayMat(1))
|| test_tile(c, IntArrayMat(1, 1))
|| test_tile(c, IntArrayMat(1, 1, 1))
|| test_tile(c, IntArrayMat(1, 3, 2, 2));
|| test_tile(a, IntArray(10))
|| test_tile(a, IntArray(10, 1))
|| test_tile(a, IntArray(5, 2, 1))
|| test_tile(a, IntArray(2, 2, 2, 3))
|| test_tile(b, IntArray(2))
|| test_tile(b, IntArray(2, 2))
|| test_tile(b, IntArray(2, 2, 1))
|| test_tile(b, IntArray(4, 1, 2, 2))
|| test_tile(c, IntArray(3))
|| test_tile(c, IntArray(4, 3))
|| test_tile(c, IntArray(1))
|| test_tile(c, IntArray(1, 1))
|| test_tile(c, IntArray(1, 1, 1))
|| test_tile(c, IntArray(1, 3, 2, 2));
}

int main()


Loading…
Cancel
Save