| @@ -370,65 +370,67 @@ void pooling_backward_max_impl(const ctype* __restrict src, | |||||
| } | } | ||||
| } | } | ||||
| } // anonymous namespace | |||||
| } // namespace | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace naive { | namespace naive { | ||||
| void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | ||||
| _megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
| MIDOUT_BEGIN(megdnn_naive_pooling) { | |||||
| check_exec(src.layout, dst.layout, workspace.size); | |||||
| size_t c_pos, spatial_pos, batch_pos = 0; | |||||
| if (param().format == Param::Format::NCHW || | |||||
| param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW88 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW32) { | |||||
| c_pos = 1; | |||||
| spatial_pos = 2; | |||||
| } else if (param().format == Param::Format::NHWC) { | |||||
| c_pos = 3; | |||||
| spatial_pos = 1; | |||||
| } else if (param().format == Param::Format::CHWN4) { | |||||
| c_pos = 0; | |||||
| spatial_pos = 1; | |||||
| batch_pos = 3; | |||||
| } else { | |||||
| megdnn_assert(param().format == Param::Format::NHWCD4); | |||||
| c_pos = 2; | |||||
| spatial_pos = 1; | |||||
| } | |||||
| size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||||
| IH = src.layout.shape[spatial_pos + 0], | |||||
| IW = src.layout.shape[spatial_pos + 1]; | |||||
| size_t OH = dst.layout.shape[spatial_pos + 0], | |||||
| OW = dst.layout.shape[spatial_pos + 1]; | |||||
| if (param().format == Param::Format::NHWCD4) { | |||||
| C *= 4; | |||||
| IW = src.layout.shape[spatial_pos + 2]; | |||||
| OW = dst.layout.shape[spatial_pos + 2]; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::CHWN4) { | |||||
| C *= 4; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW88) { | |||||
| C *= 8; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW32) { | |||||
| C *= 32; | |||||
| } | |||||
| size_t PH = param().pad_h, PW = param().pad_w; | |||||
| size_t FH = param().window_h, FW = param().window_w; | |||||
| size_t SH = param().stride_h, SW = param().stride_w; | |||||
| #define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| static_cast<naive::HandleImpl*>(handle()), \ | |||||
| pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||||
| sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, PH, \ | |||||
| PW, SH, SW, FH, FW)); | |||||
| check_exec(src.layout, dst.layout, workspace.size); | |||||
| size_t c_pos, spatial_pos, batch_pos = 0; | |||||
| if (param().format == Param::Format::NCHW || | |||||
| param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW88 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW32) { | |||||
| c_pos = 1; | |||||
| spatial_pos = 2; | |||||
| } else if (param().format == Param::Format::NHWC) { | |||||
| c_pos = 3; | |||||
| spatial_pos = 1; | |||||
| } else if (param().format == Param::Format::CHWN4) { | |||||
| c_pos = 0; | |||||
| spatial_pos = 1; | |||||
| batch_pos = 3; | |||||
| } else { | |||||
| megdnn_assert(param().format == Param::Format::NHWCD4); | |||||
| c_pos = 2; | |||||
| spatial_pos = 1; | |||||
| } | |||||
| size_t N = src.layout.shape[batch_pos], C = src.layout.shape[c_pos], | |||||
| IH = src.layout.shape[spatial_pos + 0], | |||||
| IW = src.layout.shape[spatial_pos + 1]; | |||||
| size_t OH = dst.layout.shape[spatial_pos + 0], | |||||
| OW = dst.layout.shape[spatial_pos + 1]; | |||||
| if (param().format == Param::Format::NHWCD4) { | |||||
| C *= 4; | |||||
| IW = src.layout.shape[spatial_pos + 2]; | |||||
| OW = dst.layout.shape[spatial_pos + 2]; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::CHWN4) { | |||||
| C *= 4; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW88) { | |||||
| C *= 8; | |||||
| } | |||||
| if (param().format == Param::Format::NCHW32) { | |||||
| C *= 32; | |||||
| } | |||||
| size_t PH = param().pad_h, PW = param().pad_w; | |||||
| size_t FH = param().window_h, FW = param().window_w; | |||||
| size_t SH = param().stride_h, SW = param().stride_w; | |||||
| #define DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, IdxGetter) \ | |||||
| MIDOUT_BEGIN(megdnn_naive_pooling, midout_iv(#Pooler #IdxGetter##_hash)) { \ | |||||
| MEGDNN_DISPATCH_CPU_KERN( \ | |||||
| static_cast<naive::HandleImpl*>(handle()), \ | |||||
| pooling_forward_impl<Pooler MEGDNN_COMMA IdxGetter>( \ | |||||
| sptr, dptr, src.layout.dtype, N, C, IH, IW, OH, OW, \ | |||||
| PH, PW, SH, SW, FH, FW)); \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| #define DISPATCH_WITH_POOLER(Pooler) \ | #define DISPATCH_WITH_POOLER(Pooler) \ | ||||
| switch (param().format) { \ | switch (param().format) { \ | ||||
| @@ -484,14 +486,12 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| } \ | } \ | ||||
| } \ | } \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||||
| #undef cb | #undef cb | ||||
| #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER | #undef DISPATCH_WITH_POOLER_AND_IDX_GETTER | ||||
| #undef DISPATCH_WITH_POOLER | #undef DISPATCH_WITH_POOLER | ||||
| megdnn_assert_internal(0); | |||||
| } | |||||
| MIDOUT_END(); | |||||
| megdnn_assert_internal(0); | |||||
| } | } | ||||
| WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle( | ||||