Browse Source

fix(megdnn): fix megdnn benchmark testcase

GitOrigin-RevId: 726971474a
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
97beae2fd8
2 changed files with 12 additions and 4 deletions
  1. +8
    -4
      dnn/test/cuda/batch_conv_bias.cpp
  2. +4
    -0
      dnn/test/cuda/benchmark.cpp

+ 8
- 4
dnn/test/cuda/batch_conv_bias.cpp View File

@@ -241,10 +241,14 @@ void benchmark_target_algo(Handle* handle, const std::vector<BenchArgs>& args,
"v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL)
benchmarker_cudnn.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_"
"GEMM" CUDNN_VERSION_STRING));
benchmarker_matmul.set_before_exec_callback(
AlgoChecker<BatchedMatrixMul>("BRUTE_FORCE-CUBLAS"));
ConvBiasForward::algo_name<ConvBias::DefaultParam>(
"CUDNN:ConvBiasActivation:"
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_"
"GEMM" CUDNN_VERSION_STRING,
{})
.c_str()));
benchmarker_matmul.set_before_exec_callback(AlgoChecker<BatchedMatrixMul>(
ExecutionPolicyAlgoName{"BRUTE_FORCE", {{"CUBLAS", {}}}}));

benchmarker.set_dtype(0, src_dtype)
.set_dtype(1, filter_dtype)


+ 4
- 0
dnn/test/cuda/benchmark.cpp View File

@@ -41,10 +41,12 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_8X8X32)
auto time_in_ms_float = benchmarker.set_param(param_float)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.execs({src_float, filter_float, {}});
auto time_in_ms_int = benchmarker.set_param(param_int)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32())
.execs({src_int, filter_int, {}});
std::cout << "1x1: N=" << N << " OC=" << OC << " IC=" << IC
<< " H=" << H << " W=" << W
@@ -67,10 +69,12 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_8X8X32)
auto time_in_ms_float = benchmarker.set_param(param_float)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.execs({src_float, filter_float, {}});
auto time_in_ms_int = benchmarker.set_param(param_int)
.set_dtype(0, dtype::Int8())
.set_dtype(1, dtype::Int8())
.set_dtype(2, dtype::Int32())
.execs({src_int, filter_int, {}});
std::cout << "chanwise: N=" << N << " C=" << C
<< " H=" << H << " W=" << W << " F=" << F


Loading…
Cancel
Save