| @@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src, | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| } // namespace | |||
| namespace megdnn { | |||
| namespace naive { | |||
| void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| MIDOUT_BEGIN(megdnn_naive_pooling) { | |||
| check_exec(src.layout, dst.layout, workspace.size); | |||
| size_t c_pos, spatial_pos, batch_pos = 0; | |||
| if (param().format == Param::Format::NCHW || | |||
| param().format == Param::Format::NCHW4 || | |||
| param().format == Param::Format::NCHW88 || | |||
| param().format == Param::Format::NCHW44 || | |||
| param().format == Param::Format::NCHW32) { | |||
| c_pos = 1; | |||
| spatial_pos = 2; | |||
| } else if (param().format == Param::Format::NHWC) { | |||
| c_pos = 3; | |||
| spatial_pos = 1; | |||
| } else if (param().format == Param::Format::CHWN4) { | |||
| c_pos = 0; | |||
| spatial_pos = 1; | |||
| batch_pos = 3; | |||
| } else { | |||
| megdnn_assert(param().format == Param::Format::NHWCD4); | |||
| c_pos = 2; | |||
| spatial_pos = 1; | |||
| } | |||
| size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||
| IH = src.layout.shape[spatial_pos + 0], | |||
| IW = src.layout.shape[spatial_pos + 1]; | |||
| size_t OH = dst.layout.shape[spatial_pos + 0], | |||
| OW = dst.layout.shape[spatial_pos + 1]; | |||
| if (param().format == Param::Format::NHWCD4) { | |||
| C *= 4; | |||
| IW = src.layout.shape[spatial_pos + 2]; | |||
| OW = dst.layout.shape[spatial_pos + 2]; | |||
| } | |||
| if (param().format == Param::Format::NCHW4 || | |||
| param().format == Param::Format::NCHW44 || | |||
| param().format == Param::Format::CHWN4) { | |||
| C *= 4; | |||
| } | |||
| if (param().format == Param::Format::NCHW88) { | |||
| C *= 8; | |||
| } | |||
| if (param().format == Param::Format::NCHW32) { | |||
| C *= 32; | |||
| } | |||
| size_t PH = param().pad_h, PW = param().pad_w; | |||
| size_t FH = param().window_h, FW = param().window_w; | |||
| size_t SH = param().stride_h, SW = param().stride_w; | |||
| #define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(handle()), \ | |||
| pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||
| sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \ | |||
| PW, SH, SW, FH, FW)); | |||
| check_exec(src.layout, dst.layout, workspace.size); | |||
| size_t c_pos, spatial_pos, batch_pos = 0; | |||
| if (param().format == Param::Format::NCHW || | |||
| param().format == Param::Format::NCHW4 || | |||
| param().format == Param::Format::NCHW88 || | |||
| param().format == Param::Format::NCHW44 || | |||
| param().format == Param::Format::NCHW32) { | |||
| c_pos = 1; | |||
| spatial_pos = 2; | |||
| } else if (param().format == Param::Format::NHWC) { | |||
| c_pos = 3; | |||
| spatial_pos = 1; | |||
| } else if (param().format == Param::Format::CHWN4) { | |||
| c_pos = 0; | |||
| spatial_pos = 1; | |||
| batch_pos = 3; | |||
| } else { | |||
| megdnn_assert(param().format == Param::Format::NHWCD4); | |||
| c_pos = 2; | |||
| spatial_pos = 1; | |||
| } | |||
| size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||
| IH = src.layout.shape[spatial_pos + 0], | |||
| IW = src.layout.shape[spatial_pos + 1]; | |||
| size_t OH = dst.layout.shape[spatial_pos + 0], | |||
| OW = dst.layout.shape[spatial_pos + 1]; | |||
| if (param().format == Param::Format::NHWCD4) { | |||
| C *= 4; | |||
| IW = src.layout.shape[spatial_pos + 2]; | |||
| OW = dst.layout.shape[spatial_pos + 2]; | |||
| } | |||
| if (param().format == Param::Format::NCHW4 || | |||
| param().format == Param::Format::NCHW44 || | |||
| param().format == Param::Format::CHWN4) { | |||
| C *= 4; | |||
| } | |||
| if (param().format == Param::Format::NCHW88) { | |||
| C *= 8; | |||
| } | |||
| if (param().format == Param::Format::NCHW32) { | |||
| C *= 32; | |||
| } | |||
| size_t PH = param().pad_h, PW = param().pad_w; | |||
| size_t FH = param().window_h, FW = param().window_w; | |||
| size_t SH = param().stride_h, SW = param().stride_w; | |||
| #define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||
| MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \ | |||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||
| static_cast<naive::HandleImpl*>(handle()), \ | |||
| pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||
| sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \ | |||
| PH, PW, SH, SW, FH, FW)); \ | |||
| } \ | |||
| MIDOUT_END(); | |||
| #define DISPATCH_WITH_POOLER(Pooler) \ | |||
| switch (param().format) { \ | |||
| @@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| } \ | |||
| } \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
| #undef cb | |||
| #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER | |||
| #undef DISPATCH_WITH_POOLER | |||
| megdnn_assert_internal(0); | |||
| } | |||
| MIDOUT_END(); | |||
| megdnn_assert_internal(0); | |||
| } | |||
| WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | |||