|
|
|
@@ -16,50 +16,55 @@ |
|
|
|
namespace megdnn { |
|
|
|
|
|
|
|
void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) { |
|
|
|
auto errmsg = |
|
|
|
megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + |
|
|
|
"pad_h=" + std::to_string(param().pad_h) + ", " + |
|
|
|
"pad_w=" + std::to_string(param().pad_w) + ", " + |
|
|
|
"stride_h=" + std::to_string(param().stride_h) + ", " + |
|
|
|
"stride_w=" + std::to_string(param().stride_w) + ", " + |
|
|
|
"window_h=" + std::to_string(param().window_h) + ", " + |
|
|
|
"window_w=" + std::to_string(param().window_w) + ", " + |
|
|
|
"is_max=" + std::to_string(param().mode == Mode::MAX) + ", " + |
|
|
|
"is_nhwc=" + std::to_string(param().format == Param::Format::NHWC) + ", " + |
|
|
|
"is_nhwcd4=" + std::to_string(param().format == Param::Format::NHWCD4); |
|
|
|
auto errmsg_c = errmsg.c_str(); |
|
|
|
|
|
|
|
MEGDNN_MARK_USED_VAR(errmsg_c); |
|
|
|
auto& p = param(); |
|
|
|
auto pformat = p.format; |
|
|
|
|
|
|
|
// Generating the error message costs about 18x as much as the rest of this |
|
|
|
// function, so it is wrapped in a lambda and only built when it is needed. |
|
|
|
auto get_errmsg = [&](void) -> std::string { |
|
|
|
std::string errmsg = |
|
|
|
megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + |
|
|
|
"pad_h=" + std::to_string(param().pad_h) + ", " + |
|
|
|
"pad_w=" + std::to_string(param().pad_w) + ", " + |
|
|
|
"stride_h=" + std::to_string(param().stride_h) + ", " + |
|
|
|
"stride_w=" + std::to_string(param().stride_w) + ", " + |
|
|
|
"window_h=" + std::to_string(param().window_h) + ", " + |
|
|
|
"window_w=" + std::to_string(param().window_w) + ", " + |
|
|
|
"is_max=" + std::to_string(param().mode == Mode::MAX) + ", " + |
|
|
|
"is_nhwc=" + std::to_string(pformat == Param::Format::NHWC) + ", " + |
|
|
|
"is_nhwcd4=" + std::to_string(pformat == Param::Format::NHWCD4); |
|
|
|
return errmsg; |
|
|
|
}; |
|
|
|
|
|
|
|
MEGDNN_MARK_USED_VAR(get_errmsg); |
|
|
|
megdnn_assert_contiguous(src); |
|
|
|
size_t spatial_pos, c_pos, batch_pos = 0; |
|
|
|
if (param().format == Param::Format::NCHW) { |
|
|
|
megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); |
|
|
|
if (pformat == Param::Format::NCHW) { |
|
|
|
megdnn_assert(src.ndim == 4_z, "%s", get_errmsg().c_str()); |
|
|
|
|
|
|
|
spatial_pos = 2; |
|
|
|
c_pos = 1; |
|
|
|
} else if (param().format == Param::Format::NHWC) { |
|
|
|
megdnn_assert(src.ndim == 4_z, "%s", errmsg_c); |
|
|
|
} else if (pformat == Param::Format::NHWC) { |
|
|
|
megdnn_assert(src.ndim == 4_z, "%s", get_errmsg().c_str()); |
|
|
|
|
|
|
|
spatial_pos = 1; |
|
|
|
c_pos = 3; |
|
|
|
} else if ( |
|
|
|
param().format == Param::Format::NCHW4 || |
|
|
|
param().format == Param::Format::NCHW44 || |
|
|
|
param().format == Param::Format::NCHW88 || |
|
|
|
param().format == Param::Format::NCHW32 || |
|
|
|
param().format == Param::Format::NCHW64) { |
|
|
|
megdnn_assert(src.ndim == 5_z, "%s", errmsg_c); |
|
|
|
pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44 || |
|
|
|
pformat == Param::Format::NCHW88 || pformat == Param::Format::NCHW32 || |
|
|
|
pformat == Param::Format::NCHW64) { |
|
|
|
megdnn_assert(src.ndim == 5_z, "%s", get_errmsg().c_str()); |
|
|
|
|
|
|
|
spatial_pos = 2; |
|
|
|
c_pos = 1; |
|
|
|
} else if (param().format == Param::Format::CHWN4) { |
|
|
|
} else if (pformat == Param::Format::CHWN4) { |
|
|
|
spatial_pos = 1; |
|
|
|
c_pos = 0; |
|
|
|
batch_pos = 3; |
|
|
|
} else { |
|
|
|
megdnn_assert( |
|
|
|
param().format == Param::Format::NHWCD4 && src.ndim == 5_z, "%s", |
|
|
|
errmsg_c); |
|
|
|
pformat == Param::Format::NHWCD4 && src.ndim == 5_z, "%s", |
|
|
|
get_errmsg().c_str()); |
|
|
|
spatial_pos = 1; |
|
|
|
c_pos = 2; |
|
|
|
} |
|
|
|
@@ -67,31 +72,34 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) |
|
|
|
size_t c = src[c_pos]; |
|
|
|
size_t ih = src[spatial_pos]; |
|
|
|
size_t iw = src[spatial_pos + 1]; |
|
|
|
if (param().format == Param::Format::NHWCD4) { |
|
|
|
if (pformat == Param::Format::NHWCD4) { |
|
|
|
c *= 4; |
|
|
|
iw = src[spatial_pos + 2]; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW4 || |
|
|
|
param().format == Param::Format::NCHW44 || |
|
|
|
param().format == Param::Format::CHWN4) { |
|
|
|
if (pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44 || |
|
|
|
pformat == Param::Format::CHWN4) { |
|
|
|
c *= 4; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW88) { |
|
|
|
if (pformat == Param::Format::NCHW88) { |
|
|
|
c *= 8; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW32) { |
|
|
|
if (pformat == Param::Format::NCHW32) { |
|
|
|
c *= 32; |
|
|
|
} |
|
|
|
if (param().format == Param::Format::NCHW64) { |
|
|
|
if (pformat == Param::Format::NCHW64) { |
|
|
|
c *= 64; |
|
|
|
} |
|
|
|
size_t oh, ow; |
|
|
|
size_t fh = this->param().window_h; |
|
|
|
size_t fw = this->param().window_w; |
|
|
|
size_t sh = this->param().stride_h; |
|
|
|
size_t sw = this->param().stride_w; |
|
|
|
size_t ph = this->param().pad_h; |
|
|
|
size_t pw = this->param().pad_w; |
|
|
|
size_t fh = p.window_h; |
|
|
|
size_t fw = p.window_w; |
|
|
|
size_t sh = p.stride_h; |
|
|
|
size_t sw = p.stride_w; |
|
|
|
size_t ph = p.pad_h; |
|
|
|
size_t pw = p.pad_w; |
|
|
|
|
|
|
|
// some asserts from the Python side are moved here |
|
|
|
// megdnn_assert() |
|
|
|
|
|
|
|
if (ph >= fh || pw >= fw) { |
|
|
|
megdnn_log_warn( |
|
|
|
"pooling padding size (%zu %zu) should not be bigger than " |
|
|
|
@@ -99,26 +107,23 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) |
|
|
|
pw, ph, fw, fh); |
|
|
|
} |
|
|
|
infer_conv_shape2d(ih, iw, fh, fw, sh, sw, ph, pw, oh, ow); |
|
|
|
if (param().format == Param::Format::NCHW) { |
|
|
|
if (pformat == Param::Format::NCHW) { |
|
|
|
dst = TensorLayout(TensorShape({n, c, oh, ow}), src.dtype); |
|
|
|
} else if (param().format == Param::Format::NHWC) { |
|
|
|
megdnn_assert(param().format == Param::Format::NHWC, "invalid pooling format"); |
|
|
|
} else if (pformat == Param::Format::NHWC) { |
|
|
|
megdnn_assert(pformat == Param::Format::NHWC, "invalid pooling format"); |
|
|
|
dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format); |
|
|
|
} else if ( |
|
|
|
param().format == Param::Format::NCHW4 || |
|
|
|
param().format == Param::Format::NCHW44) { |
|
|
|
} else if (pformat == Param::Format::NCHW4 || pformat == Param::Format::NCHW44) { |
|
|
|
dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format}; |
|
|
|
} else if (param().format == Param::Format::NCHW88) { |
|
|
|
} else if (pformat == Param::Format::NCHW88) { |
|
|
|
dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format}; |
|
|
|
} else if (param().format == Param::Format::NCHW32) { |
|
|
|
} else if (pformat == Param::Format::NCHW32) { |
|
|
|
dst = TensorLayout{{n, c / 32, oh, ow, 32}, src.dtype, src.format}; |
|
|
|
} else if (param().format == Param::Format::NCHW64) { |
|
|
|
} else if (pformat == Param::Format::NCHW64) { |
|
|
|
dst = TensorLayout{{n, c / 64, oh, ow, 64}, src.dtype, src.format}; |
|
|
|
} else if (param().format == Param::Format::CHWN4) { |
|
|
|
} else if (pformat == Param::Format::CHWN4) { |
|
|
|
dst = TensorLayout{{c / 4, oh, ow, n, 4}, src.dtype, src.format}; |
|
|
|
} else { |
|
|
|
megdnn_assert( |
|
|
|
param().format == Param::Format::NHWCD4, "invalid pooling format"); |
|
|
|
megdnn_assert(pformat == Param::Format::NHWCD4, "invalid pooling format"); |
|
|
|
dst = TensorLayout{{n, oh, c / 4, ow, 4}, src.dtype, src.format}; |
|
|
|
} |
|
|
|
} |
|
|
|
|