From 3abe0b2462c25da260d3d0398fa2c8eee4449096 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 26 Aug 2021 18:26:50 +0800 Subject: [PATCH] fix(mgb): fix rocm pooling GitOrigin-RevId: 44876d398ed56214d71e117ada0b21085a2c05de --- dnn/src/rocm/pooling/algo.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dnn/src/rocm/pooling/algo.cpp b/dnn/src/rocm/pooling/algo.cpp index 471bd310..2a354f10 100644 --- a/dnn/src/rocm/pooling/algo.cpp +++ b/dnn/src/rocm/pooling/algo.cpp @@ -60,10 +60,10 @@ void PoolingForwardImpl::AlgoMIOpen::init_mode( case param::Pooling::Mode::MAX: mode = miopenPoolingMax; break; - case param::Pooling::Mode::AVERAGE: + case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: mode = miopenPoolingAverage; break; - case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: + case param::Pooling::Mode::AVERAGE: mode = miopenPoolingAverageInclusive; break; default: @@ -96,7 +96,7 @@ void PoolingForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const { miopen_check(miopenPoolingForward( handle, miopen_desc, &alpha, src_desc.desc, args.src_tensor->raw_ptr, &beta, dst_desc.desc, - args.src_tensor->raw_ptr, false, nullptr, 0_z)); + args.dst_tensor->raw_ptr, false, nullptr, 0_z)); miopen_check(miopenDestroyPoolingDescriptor(miopen_desc)); } @@ -163,10 +163,10 @@ void PoolingBackwardImpl::AlgoMIOpen::init_mode(const ExecArgs& args, case param::Pooling::Mode::MAX: mode = miopenPoolingMax; break; - case param::Pooling::Mode::AVERAGE: + case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: mode = miopenPoolingAverage; break; - case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING: + case param::Pooling::Mode::AVERAGE: mode = miopenPoolingAverageInclusive; break; default: