GitOrigin-RevId: 31e7b72a78
tags/v1.9.0
| @@ -18,21 +18,22 @@ namespace megdnn { | |||||
| void RemapBase::deduce_layout_fwd( | void RemapBase::deduce_layout_fwd( | ||||
| const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst) { | ||||
| dst.dtype = src.dtype; | |||||
| dst.ndim = src.ndim; | |||||
| dst.shape[0] = src.shape[0]; | |||||
| size_t height_index, channel_index; | |||||
| size_t n = src.shape[0]; | |||||
| size_t c, oh, ow; | |||||
| oh = map_xy.shape[1]; | |||||
| ow = map_xy.shape[2]; | |||||
| if (param().format == param::Remap::Format::NHWC) { | if (param().format == param::Remap::Format::NHWC) { | ||||
| height_index = 1; | |||||
| channel_index = 3; | |||||
| c = src.shape[3]; | |||||
| dst = TensorLayout(TensorShape({n, oh, ow, c}), src.dtype); | |||||
| } else if (param().format == param::Remap::Format::NCHW) { | |||||
| c = src.shape[1]; | |||||
| dst = TensorLayout(TensorShape{n, c, oh, ow}, src.dtype, src.format); | |||||
| } else if (param().format == param::Remap::Format::NHWCD4) { | |||||
| c = src.shape[2]; | |||||
| dst = TensorLayout{{n, oh, c, ow, 4}, src.dtype, src.format}; | |||||
| } else { | } else { | ||||
| megdnn_assert(param().format == param::Remap::Format::NCHW); | |||||
| height_index = 2; | |||||
| channel_index = 1; | |||||
| megdnn_throw("unsupport format"); | |||||
| } | } | ||||
| dst.shape[height_index] = map_xy.shape[1]; | |||||
| dst.shape[height_index + 1] = map_xy.shape[2]; | |||||
| dst.shape[channel_index] = src.shape[channel_index]; | |||||
| } | } | ||||
| void RemapBase::check_layout_fwd( | void RemapBase::check_layout_fwd( | ||||
| @@ -42,7 +43,7 @@ void RemapBase::check_layout_fwd( | |||||
| megdnn_layout_msg(dst); | megdnn_layout_msg(dst); | ||||
| }; | }; | ||||
| MEGDNN_MARK_USED_VAR(errmsg); | MEGDNN_MARK_USED_VAR(errmsg); | ||||
| megdnn_assert(src.ndim == map_xy.ndim && src.ndim == dst.ndim && src.ndim == 4); | |||||
| megdnn_assert(src.ndim == dst.ndim); | |||||
| megdnn_assert(dst.dtype == src.dtype); | megdnn_assert(dst.dtype == src.dtype); | ||||
| megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); | megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); | ||||
| megdnn_assert(map_xy.shape[3] == 2); | megdnn_assert(map_xy.shape[3] == 2); | ||||
| @@ -64,10 +65,13 @@ void RemapBase::check_layout_fwd( | |||||
| megdnn_assert( | megdnn_assert( | ||||
| dst.shape[2] == map_xy.shape[1] && dst.shape[3] == map_xy.shape[2], | dst.shape[2] == map_xy.shape[1] && dst.shape[3] == map_xy.shape[2], | ||||
| "%s", errmsg().c_str()); | "%s", errmsg().c_str()); | ||||
| } else if (param().format == param::Remap::Format::NHWCD4) { | |||||
| megdnn_assert(src.shape[2] == dst.shape[2], "%s", errmsg().c_str()); | |||||
| megdnn_assert(src.ndim == 5_z, "%s", errmsg().c_str()); | |||||
| megdnn_assert(dst.ndim == 5_z, "%s", errmsg().c_str()); | |||||
| megdnn_assert(param().format == Param::Format::NHWCD4); | |||||
| } else { | } else { | ||||
| megdnn_throw( | |||||
| "currently do not support other param.format except NHWC and " | |||||
| "NCHW"); | |||||
| megdnn_throw("unsupport format"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -22,8 +22,9 @@ void RemapBackwardDataImpl::exec( | |||||
| _megdnn_workspace workspace) { | _megdnn_workspace workspace) { | ||||
| check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); | check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); | ||||
| megdnn_assert( | megdnn_assert( | ||||
| param().imode == param::Remap::InterpolationMode::LINEAR, | |||||
| "only support LINEAR interpolationMode"); | |||||
| (param().imode == param::Remap::InterpolationMode::NEAREST) || | |||||
| (param().imode == param::Remap::InterpolationMode::LINEAR), | |||||
| "only support NEAREST and LINEAR interpolationMode"); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| param().format == param::Remap::Format::NCHW, | param().format == param::Remap::Format::NCHW, | ||||
| "only support NCHW format for remap backward"); | "only support NCHW format for remap backward"); | ||||
| @@ -36,13 +37,15 @@ void RemapBackwardDataImpl::exec( | |||||
| OH = map_xy.layout.shape[1]; | OH = map_xy.layout.shape[1]; | ||||
| OW = map_xy.layout.shape[2]; | OW = map_xy.layout.shape[2]; | ||||
| #define cb(dt, _format, bmode) \ | |||||
| #define cb(dt, _format, bmode, inter_mode) \ | |||||
| if (param().format == param::Remap::Format::_format && \ | if (param().format == param::Remap::Format::_format && \ | ||||
| param().border_type == param::Remap::BorderMode::bmode) { \ | |||||
| param().border_type == param::Remap::BorderMode::bmode && \ | |||||
| param().imode == param::Remap::InterpolationMode::inter_mode) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | using ctype = DTypeTrait<dt>::ctype; \ | ||||
| remap::backwarddata_proxy< \ | remap::backwarddata_proxy< \ | ||||
| ctype, param_enumv::Remap::Format::_format, \ | ctype, param_enumv::Remap::Format::_format, \ | ||||
| ::BorderMode::BORDER_##bmode>( \ | |||||
| ::BorderMode::BORDER_##bmode, \ | |||||
| ::InterpolationMode::INTER_##inter_mode>( \ | |||||
| grad.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | grad.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | ||||
| diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \ | diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \ | ||||
| break; \ | break; \ | ||||
| @@ -50,11 +53,16 @@ void RemapBackwardDataImpl::exec( | |||||
| #define support_dtype(dt) \ | #define support_dtype(dt) \ | ||||
| case DTypeTrait<dt>::enumv: { \ | case DTypeTrait<dt>::enumv: { \ | ||||
| cb(dt, NCHW, CONSTANT); \ | |||||
| cb(dt, NCHW, REPLICATE); \ | |||||
| cb(dt, NCHW, REFLECT); \ | |||||
| cb(dt, NCHW, REFLECT_101); \ | |||||
| cb(dt, NCHW, WRAP); \ | |||||
| cb(dt, NCHW, CONSTANT, NEAREST); \ | |||||
| cb(dt, NCHW, REPLICATE, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NCHW, WRAP, NEAREST); \ | |||||
| cb(dt, NCHW, CONSTANT, LINEAR); \ | |||||
| cb(dt, NCHW, REPLICATE, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||||
| cb(dt, NCHW, WRAP, LINEAR); \ | |||||
| megdnn_throw("unsupported border type in remap cuda"); \ | megdnn_throw("unsupported border type in remap cuda"); \ | ||||
| } | } | ||||
| @@ -52,8 +52,49 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||||
| } | } | ||||
| }; | }; | ||||
// Rounds f to the nearest integral value, resolving exact .5 ties toward the
// even neighbour instead of away from zero.
// Returns the rounded value as a float.
| __device__ inline float round_half_to_even(float f) { | |||||
| const float round_away_from_zero = round(f); | |||||
| const float diff = round_away_from_zero - f; | |||||
// diff is exactly +/-0.5 only when f sits halfway between two integers;
// in every other case the result of round() is already the answer.
| if ((diff != 0.5f) && (diff != -0.5f)) { | |||||
| return round_away_from_zero; | |||||
| } | |||||
// Tie, and round() already landed on an even value: keep it.
| if (fmod(round_away_from_zero, 2.0f) == 0.0f) { | |||||
| return round_away_from_zero; | |||||
| } | |||||
// Tie, and round() landed on an odd value: f - diff is the integer on the
// other side of f, which is the even neighbour.
| return f - diff; | |||||
| } | |||||
// Backward-data kernel for nearest-neighbour remap: every in-range (oh, ow)
// thread reads its coordinate pair from map_xy, rounds it with
// round_half_to_even, and adds the matching diff value of each channel into
// grad at the index reported by GetSrcData::get_index.
//   grad   - accumulated into via atomic_add
//   map_xy - two floats per output position: [0] read as column, [1] as row
//   diff   - read once per (oh, ow, c)
//   C, IH, IW, OH, OW - extents used for the offset/index computations
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| __global__ void kern_general_nearest( | |||||
| ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH, | |||||
| int IW, int OH, int OW) { | |||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
// Advance each pointer by blockIdx.z whole items (C*IH*IW, C*OH*OW and 2*OH*OW
// elements respectively). NOTE(review): presumably blockIdx.z is the batch
// index within the current launch - confirm against the launch configuration.
| grad += blockIdx.z * C * IH * IW; | |||||
| diff += blockIdx.z * C * OH * OW; | |||||
| map_xy += blockIdx.z * 2 * OH * OW; | |||||
// Threads outside the OW x OH extent do nothing.
| if (ow < OW && oh < OH) { | |||||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||||
| int col = static_cast<int>(round_half_to_even(index_col)); | |||||
| int row = static_cast<int>(round_half_to_even(index_row)); | |||||
| for (int c = 0; c < C; ++c) { | |||||
| ctype hidden = diff[get_offset<format>(oh, ow, c, OH, OW, C)]; | |||||
| int idx = | |||||
| GetSrcData<ctype, format, bmode>::get_index(row, col, c, IH, IW, C); | |||||
// An index of -1 is skipped; presumably it marks a coordinate with no source
// element under this border mode - verify in GetSrcData.
| if (idx != -1) { | |||||
// NOTE(review): atomic presumably because several output positions can map to
// the same grad element - confirm.
| atomic_add(grad + idx, hidden); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | template <typename ctype, const uint32_t format, ::BorderMode bmode> | ||||
| __global__ void kern_general( | |||||
| __global__ void kern_general_linear( | |||||
| ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH, | ctype* __restrict grad, const float* map_xy, const ctype* diff, int C, int IH, | ||||
| int IW, int OH, int OW) { | int IW, int OH, int OW) { | ||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| @@ -93,8 +134,8 @@ __global__ void kern_general( | |||||
| atomic_add(grad + a10, round_converter(u * (one - v) * hidden)); | atomic_add(grad + a10, round_converter(u * (one - v) * hidden)); | ||||
| } | } | ||||
| int a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>:: | |||||
| get_index(row + 1, col + 1, c, IH, IW, C); | |||||
| int a11 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 1, col + 1, c, IH, IW, C); | |||||
| if (a11 != -1) { | if (a11 != -1) { | ||||
| atomic_add(grad + a11, round_converter(u * v * hidden)); | atomic_add(grad + a11, round_converter(u * v * hidden)); | ||||
| } | } | ||||
| @@ -102,7 +143,9 @@ __global__ void kern_general( | |||||
| } | } | ||||
| } | } | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void dispatch_backwarddata( | void dispatch_backwarddata( | ||||
| ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | ||||
| int IW, int OH, int OW, cudaStream_t stream) { | int IW, int OH, int OW, cudaStream_t stream) { | ||||
| @@ -115,8 +158,13 @@ void dispatch_backwarddata( | |||||
| cuda_check(cudaMemsetAsync( | cuda_check(cudaMemsetAsync( | ||||
| grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, stream)); | grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, stream)); | ||||
| kern_general<ctype, format, bmode> | |||||
| <<<blocks, threads, 0, stream>>>(grad, map_xy, diff, C, IH, IW, OH, OW); | |||||
| if (imode == ::InterpolationMode::INTER_NEAREST) { | |||||
| kern_general_nearest<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| grad, map_xy, diff, C, IH, IW, OH, OW); | |||||
| } else if (imode == ::InterpolationMode::INTER_LINEAR) { | |||||
| kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| grad, map_xy, diff, C, IH, IW, OH, OW); | |||||
| } | |||||
| N -= curr_batch_size; | N -= curr_batch_size; | ||||
| grad += curr_batch_size * C * IH * IW; | grad += curr_batch_size * C * IH * IW; | ||||
| @@ -131,27 +179,35 @@ namespace megdnn { | |||||
| namespace cuda { | namespace cuda { | ||||
| namespace remap { | namespace remap { | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void backwarddata_proxy( | void backwarddata_proxy( | ||||
| ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | ||||
| int IW, int OH, int OW, cudaStream_t stream) { | int IW, int OH, int OW, cudaStream_t stream) { | ||||
| dispatch_backwarddata<ctype, format, bmode>( | |||||
| dispatch_backwarddata<ctype, format, bmode, imode>( | |||||
| grad, map_xy, diff, N, C, IH, IW, OH, OW, stream); | grad, map_xy, diff, N, C, IH, IW, OH, OW, stream); | ||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| #define INST(ctype, format, bmode) \ | |||||
| #define INST(ctype, format, bmode, imode) \ | |||||
| template void backwarddata_proxy< \ | template void backwarddata_proxy< \ | ||||
| ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||||
| ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ | |||||
| ::InterpolationMode::imode>( \ | |||||
| ctype*, const float*, const ctype*, int, int, int, int, int, int, \ | ctype*, const float*, const ctype*, int, int, int, int, int, int, \ | ||||
| cudaStream_t); | cudaStream_t); | ||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP) | |||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) | |||||
| FOR_FORMAT_BMODE(float) | FOR_FORMAT_BMODE(float) | ||||
| DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | ||||
| @@ -22,8 +22,9 @@ void RemapBackwardMatImpl::exec( | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) { | _megdnn_tensor_out grad, _megdnn_workspace workspace) { | ||||
| check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size); | check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, workspace.size); | ||||
| megdnn_assert( | megdnn_assert( | ||||
| param().imode == param::Remap::InterpolationMode::LINEAR, | |||||
| "only support LINEAR interpolationMode"); | |||||
| (param().imode == param::Remap::InterpolationMode::NEAREST) || | |||||
| (param().imode == param::Remap::InterpolationMode::LINEAR), | |||||
| "only support NEAREST and LINEAR interpolationMode"); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| param().format == param::Remap::Format::NCHW, | param().format == param::Remap::Format::NCHW, | ||||
| "only support NCHW format for remap backward"); | "only support NCHW format for remap backward"); | ||||
| @@ -36,13 +37,15 @@ void RemapBackwardMatImpl::exec( | |||||
| OH = map_xy.layout.shape[1]; | OH = map_xy.layout.shape[1]; | ||||
| OW = map_xy.layout.shape[2]; | OW = map_xy.layout.shape[2]; | ||||
| #define cb(dt, _format, bmode) \ | |||||
| #define cb(dt, _format, bmode, inter_mode) \ | |||||
| if (param().format == param::Remap::Format::_format && \ | if (param().format == param::Remap::Format::_format && \ | ||||
| param().border_type == param::Remap::BorderMode::bmode) { \ | |||||
| param().border_type == param::Remap::BorderMode::bmode && \ | |||||
| param().imode == param::Remap::InterpolationMode::inter_mode) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | using ctype = DTypeTrait<dt>::ctype; \ | ||||
| remap::backwardmat_proxy< \ | remap::backwardmat_proxy< \ | ||||
| ctype, param_enumv::Remap::Format::_format, \ | ctype, param_enumv::Remap::Format::_format, \ | ||||
| ::BorderMode::BORDER_##bmode>( \ | |||||
| ::BorderMode::BORDER_##bmode, \ | |||||
| ::InterpolationMode::INTER_##inter_mode>( \ | |||||
| src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | ||||
| diff.compatible_ptr<ctype>(), grad.compatible_ptr<dt_float32>(), N, C, \ | diff.compatible_ptr<ctype>(), grad.compatible_ptr<dt_float32>(), N, C, \ | ||||
| IH, IW, OH, OW, param().scalar, stream); \ | IH, IW, OH, OW, param().scalar, stream); \ | ||||
| @@ -51,11 +54,16 @@ void RemapBackwardMatImpl::exec( | |||||
| #define support_dtype(dt) \ | #define support_dtype(dt) \ | ||||
| case DTypeTrait<dt>::enumv: { \ | case DTypeTrait<dt>::enumv: { \ | ||||
| cb(dt, NCHW, CONSTANT); \ | |||||
| cb(dt, NCHW, REPLICATE); \ | |||||
| cb(dt, NCHW, REFLECT); \ | |||||
| cb(dt, NCHW, REFLECT_101); \ | |||||
| cb(dt, NCHW, WRAP); \ | |||||
| cb(dt, NCHW, CONSTANT, NEAREST); \ | |||||
| cb(dt, NCHW, REPLICATE, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NCHW, WRAP, NEAREST); \ | |||||
| cb(dt, NCHW, CONSTANT, LINEAR); \ | |||||
| cb(dt, NCHW, REPLICATE, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||||
| cb(dt, NCHW, WRAP, LINEAR); \ | |||||
| megdnn_throw("unsupported border type in remap cuda"); \ | megdnn_throw("unsupported border type in remap cuda"); \ | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||||
| }; | }; | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | template <typename ctype, const uint32_t format, ::BorderMode bmode> | ||||
| __global__ void kern_general( | |||||
| __global__ void kern_general_linear( | |||||
| const ctype* src, const float* map_xy, const ctype* diff, | const ctype* src, const float* map_xy, const ctype* diff, | ||||
| float* __restrict grad, int C, int IH, int IW, int OH, int OW, float scalar) { | float* __restrict grad, int C, int IH, int IW, int OH, int OW, float scalar) { | ||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| @@ -62,7 +62,6 @@ __global__ void kern_general( | |||||
| diff += blockIdx.z * C * OH * OW; | diff += blockIdx.z * C * OH * OW; | ||||
| map_xy += blockIdx.z * 2 * OH * OW; | map_xy += blockIdx.z * 2 * OH * OW; | ||||
| grad += blockIdx.z * 2 * OH * OW; | grad += blockIdx.z * 2 * OH * OW; | ||||
| RoundingConverter<ctype> round_converter; | |||||
| if (ow < OW && oh < OH) { | if (ow < OW && oh < OH) { | ||||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | ||||
| @@ -86,23 +85,25 @@ __global__ void kern_general( | |||||
| int a11 = GetSrcData<ctype, format, bmode>::get_index( | int a11 = GetSrcData<ctype, format, bmode>::get_index( | ||||
| row + 1, col + 1, c, IH, IW, C); | row + 1, col + 1, c, IH, IW, C); | ||||
| dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||||
| dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||||
| dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||||
| dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||||
| dv -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * (one - u); | |||||
| dv += ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * (one - u); | |||||
| dv -= ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * u; | |||||
| dv += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * u; | |||||
| du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||||
| du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||||
| du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||||
| du += ((a11 != -1) ? src[a11] : scalar) * v; | |||||
| du -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * (one - v); | |||||
| du -= ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * v; | |||||
| du += ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * (one - v); | |||||
| du += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * v; | |||||
| grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv); | |||||
| grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du); | |||||
| grad[oh * OW * 2 + ow * 2 + 0] += hidden * dv; | |||||
| grad[oh * OW * 2 + ow * 2 + 1] += hidden * du; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void dispatch_backwardmat( | void dispatch_backwardmat( | ||||
| const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | ||||
| int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { | int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { | ||||
| @@ -115,8 +116,11 @@ void dispatch_backwardmat( | |||||
| cuda_check(cudaMemsetAsync( | cuda_check(cudaMemsetAsync( | ||||
| grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, stream)); | grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, stream)); | ||||
| kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); | |||||
| if (imode == ::InterpolationMode::INTER_LINEAR) { | |||||
| kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); | |||||
| } | |||||
| N -= curr_batch_size; | N -= curr_batch_size; | ||||
| src += curr_batch_size * C * IH * IW; | src += curr_batch_size * C * IH * IW; | ||||
| @@ -132,27 +136,35 @@ namespace megdnn { | |||||
| namespace cuda { | namespace cuda { | ||||
| namespace remap { | namespace remap { | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void backwardmat_proxy( | void backwardmat_proxy( | ||||
| const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | ||||
| int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { | int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream) { | ||||
| dispatch_backwardmat<ctype, format, bmode>( | |||||
| dispatch_backwardmat<ctype, format, bmode, imode>( | |||||
| src, map_xy, diff, grad, N, C, IH, IW, OH, OW, scalar, stream); | src, map_xy, diff, grad, N, C, IH, IW, OH, OW, scalar, stream); | ||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| #define INST(ctype, format, bmode) \ | |||||
| template void \ | |||||
| backwardmat_proxy<ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||||
| #define INST(ctype, format, bmode, imode) \ | |||||
| template void backwardmat_proxy< \ | |||||
| ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ | |||||
| ::InterpolationMode::imode>( \ | |||||
| const ctype*, const float*, const ctype*, float*, int, int, int, int, int, \ | const ctype*, const float*, const ctype*, float*, int, int, int, int, int, \ | ||||
| int, float, cudaStream_t); | int, float, cudaStream_t); | ||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP) | |||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) | |||||
| FOR_FORMAT_BMODE(float) | FOR_FORMAT_BMODE(float) | ||||
| DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | ||||
| @@ -21,17 +21,23 @@ namespace remap { | |||||
| // all these kernels use LINEAR interpolation | // all these kernels use LINEAR interpolation | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void forward_proxy( | void forward_proxy( | ||||
| const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | ||||
| int OH, int OW, float scalar, cudaStream_t stream); | int OH, int OW, float scalar, cudaStream_t stream); | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void backwarddata_proxy( | void backwarddata_proxy( | ||||
| ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | ||||
| int IW, int OH, int OW, cudaStream_t stream); | int IW, int OH, int OW, cudaStream_t stream); | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void backwardmat_proxy( | void backwardmat_proxy( | ||||
| const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | ||||
| int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream); | int C, int IH, int IW, int OH, int OW, float scalar, cudaStream_t stream); | ||||
| @@ -30,8 +30,9 @@ void RemapImpl::exec( | |||||
| OW = map_xy.layout.shape[2]; | OW = map_xy.layout.shape[2]; | ||||
| megdnn_assert( | megdnn_assert( | ||||
| param().imode == param::Remap::InterpolationMode::LINEAR, | |||||
| "only support LINEAR interpolationMode"); | |||||
| (param().imode == param::Remap::InterpolationMode::NEAREST) || | |||||
| (param().imode == param::Remap::InterpolationMode::LINEAR), | |||||
| "only support NEAREST and LINEAR interpolationMode"); | |||||
| if (param().format == param::Remap::Format::NCHW) { | if (param().format == param::Remap::Format::NCHW) { | ||||
| N = src.layout.shape[0]; | N = src.layout.shape[0]; | ||||
| @@ -47,13 +48,15 @@ void RemapImpl::exec( | |||||
| megdnn_throw("unsupported format, cuda remap"); | megdnn_throw("unsupported format, cuda remap"); | ||||
| } | } | ||||
| #define cb(dt, _format, bmode) \ | |||||
| #define cb(dt, _format, bmode, inter_mode) \ | |||||
| if (param().format == param::Remap::Format::_format && \ | if (param().format == param::Remap::Format::_format && \ | ||||
| param().border_type == param::Remap::BorderMode::bmode) { \ | |||||
| param().border_type == param::Remap::BorderMode::bmode && \ | |||||
| param().imode == param::Remap::InterpolationMode::inter_mode) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | using ctype = DTypeTrait<dt>::ctype; \ | ||||
| remap::forward_proxy< \ | remap::forward_proxy< \ | ||||
| ctype, param_enumv::Remap::Format::_format, \ | ctype, param_enumv::Remap::Format::_format, \ | ||||
| ::BorderMode::BORDER_##bmode>( \ | |||||
| ::BorderMode::BORDER_##bmode, \ | |||||
| ::InterpolationMode::INTER_##inter_mode>( \ | |||||
| src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | src.compatible_ptr<ctype>(), map_xy.compatible_ptr<dt_float32>(), \ | ||||
| dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, param().scalar, \ | dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, param().scalar, \ | ||||
| stream); \ | stream); \ | ||||
| @@ -62,16 +65,26 @@ void RemapImpl::exec( | |||||
| #define support_dtype(dt) \ | #define support_dtype(dt) \ | ||||
| case DTypeTrait<dt>::enumv: { \ | case DTypeTrait<dt>::enumv: { \ | ||||
| cb(dt, NCHW, CONSTANT); \ | |||||
| cb(dt, NCHW, REPLICATE); \ | |||||
| cb(dt, NCHW, REFLECT); \ | |||||
| cb(dt, NCHW, REFLECT_101); \ | |||||
| cb(dt, NCHW, WRAP); \ | |||||
| cb(dt, NHWC, CONSTANT); \ | |||||
| cb(dt, NHWC, REPLICATE); \ | |||||
| cb(dt, NHWC, REFLECT); \ | |||||
| cb(dt, NHWC, REFLECT_101); \ | |||||
| cb(dt, NHWC, WRAP); \ | |||||
| cb(dt, NCHW, CONSTANT, NEAREST); \ | |||||
| cb(dt, NCHW, REPLICATE, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NCHW, WRAP, NEAREST); \ | |||||
| cb(dt, NHWC, CONSTANT, NEAREST); \ | |||||
| cb(dt, NHWC, REPLICATE, NEAREST); \ | |||||
| cb(dt, NHWC, REFLECT, NEAREST); \ | |||||
| cb(dt, NHWC, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NHWC, WRAP, NEAREST); \ | |||||
| cb(dt, NCHW, CONSTANT, LINEAR); \ | |||||
| cb(dt, NCHW, REPLICATE, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||||
| cb(dt, NCHW, WRAP, LINEAR); \ | |||||
| cb(dt, NHWC, CONSTANT, LINEAR); \ | |||||
| cb(dt, NHWC, REPLICATE, LINEAR); \ | |||||
| cb(dt, NHWC, REFLECT, LINEAR); \ | |||||
| cb(dt, NHWC, REFLECT_101, LINEAR); \ | |||||
| cb(dt, NHWC, WRAP, LINEAR); \ | |||||
| megdnn_throw("unsupported border type in remap cuda"); \ | megdnn_throw("unsupported border type in remap cuda"); \ | ||||
| } | } | ||||
| @@ -62,8 +62,23 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename ctype, ::BorderMode bmode> | |||||
| __global__ void kern_general( | |||||
// Rounds f to the nearest integral value, resolving exact .5 ties toward the
// even neighbour instead of away from zero.
// Returns the rounded value as a float.
| __device__ inline float round_half_to_even(float f) { | |||||
| const float round_away_from_zero = round(f); | |||||
| const float diff = round_away_from_zero - f; | |||||
// diff is exactly +/-0.5 only when f sits halfway between two integers;
// in every other case the result of round() is already the answer.
| if ((diff != 0.5f) && (diff != -0.5f)) { | |||||
| return round_away_from_zero; | |||||
| } | |||||
// Tie, and round() already landed on an even value: keep it.
| if (fmod(round_away_from_zero, 2.0f) == 0.0f) { | |||||
| return round_away_from_zero; | |||||
| } | |||||
// Tie, and round() landed on an odd value: f - diff is the integer on the
// other side of f, which is the even neighbour.
| return f - diff; | |||||
| } | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| __global__ void kern_general_nearest( | |||||
| const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, | const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, | ||||
| int IH, int IW, int OH, int OW, float scalar) { | int IH, int IW, int OH, int OW, float scalar) { | ||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| @@ -71,37 +86,22 @@ __global__ void kern_general( | |||||
| sptr += blockIdx.z * C * IH * IW; | sptr += blockIdx.z * C * IH * IW; | ||||
| dst += blockIdx.z * C * OH * OW; | dst += blockIdx.z * C * OH * OW; | ||||
| map_xy += blockIdx.z * 2 * OH * OW; | map_xy += blockIdx.z * 2 * OH * OW; | ||||
| RoundingConverter<ctype> round_converter; | |||||
| if (ow < OW && oh < OH) { | if (ow < OW && oh < OH) { | ||||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | ||||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | ||||
| int col = static_cast<int>(floor(index_col)); | |||||
| int row = static_cast<int>(floor(index_row)); | |||||
| float v = index_col - col; | |||||
| float u = index_row - row; | |||||
| int col = static_cast<int>(round_half_to_even(index_col)); | |||||
| int row = static_cast<int>(round_half_to_even(index_row)); | |||||
| for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
| ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||||
| sptr, row + 0, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||||
| sptr, row + 0, col + 1, c, IH, IW, C, scalar); | |||||
| ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||||
| sptr, row + 1, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, bmode>::get( | |||||
| sptr, row + 1, col + 1, c, IH, IW, C, scalar); | |||||
| /* in remap, we use float as the type of intermediate result */ | |||||
| float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | |||||
| static_cast<float>(a01) * (1.f - u) * v + | |||||
| static_cast<float>(a10) * (1.f - v) * u + | |||||
| static_cast<float>(a11) * u * v; | |||||
| dst[get_offset<param_enumv::Remap::Format::NCHW>(oh, ow, c, OH, OW, C)] = | |||||
| round_converter(result); | |||||
| dst[get_offset<format>(oh, ow, c, OH, OW, C)] = | |||||
| GetSrcData<ctype, format, bmode>::get( | |||||
| sptr, row, col, c, IH, IW, C, scalar); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename ctype, ::BorderMode bmode> | |||||
| __global__ void kern_general_nhwc( | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| __global__ void kern_general_linear( | |||||
| const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, | const ctype* __restrict sptr, const float* map_xy, ctype* __restrict dst, int C, | ||||
| int IH, int IW, int OH, int OW, float scalar) { | int IH, int IW, int OH, int OW, float scalar) { | ||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| @@ -119,26 +119,27 @@ __global__ void kern_general_nhwc( | |||||
| float v = index_col - col; | float v = index_col - col; | ||||
| float u = index_row - row; | float u = index_row - row; | ||||
| for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
| ctype a00 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||||
| ctype a00 = GetSrcData<ctype, format, bmode>::get( | |||||
| sptr, row + 0, col + 0, c, IH, IW, C, scalar); | sptr, row + 0, col + 0, c, IH, IW, C, scalar); | ||||
| ctype a01 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||||
| ctype a01 = GetSrcData<ctype, format, bmode>::get( | |||||
| sptr, row + 0, col + 1, c, IH, IW, C, scalar); | sptr, row + 0, col + 1, c, IH, IW, C, scalar); | ||||
| ctype a10 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||||
| ctype a10 = GetSrcData<ctype, format, bmode>::get( | |||||
| sptr, row + 1, col + 0, c, IH, IW, C, scalar); | sptr, row + 1, col + 0, c, IH, IW, C, scalar); | ||||
| ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, bmode>::get( | |||||
| ctype a11 = GetSrcData<ctype, format, bmode>::get( | |||||
| sptr, row + 1, col + 1, c, IH, IW, C, scalar); | sptr, row + 1, col + 1, c, IH, IW, C, scalar); | ||||
| /* in remap, we use float as the type of intermediate result */ | /* in remap, we use float as the type of intermediate result */ | ||||
| float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | ||||
| static_cast<float>(a01) * (1.f - u) * v + | static_cast<float>(a01) * (1.f - u) * v + | ||||
| static_cast<float>(a10) * (1.f - v) * u + | static_cast<float>(a10) * (1.f - v) * u + | ||||
| static_cast<float>(a11) * u * v; | static_cast<float>(a11) * u * v; | ||||
| dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH, OW, C)] = | |||||
| round_converter(result); | |||||
| dst[get_offset<format>(oh, ow, c, OH, OW, C)] = round_converter(result); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void dispatch_forward( | void dispatch_forward( | ||||
| const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | ||||
| int OH, int OW, float scalar, cudaStream_t stream) { | int OH, int OW, float scalar, cudaStream_t stream) { | ||||
| @@ -150,11 +151,11 @@ void dispatch_forward( | |||||
| dim3 threads(BX, BY); | dim3 threads(BX, BY); | ||||
| dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | ||||
| if (format == param_enumv::Remap::Format::NCHW) { | |||||
| kern_general<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||||
| if (imode == ::InterpolationMode::INTER_NEAREST) { | |||||
| kern_general_nearest<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| src, map_xy, dst, C, IH, IW, OH, OW, scalar); | src, map_xy, dst, C, IH, IW, OH, OW, scalar); | ||||
| } else if (format == param_enumv::Remap::Format::NHWC) { | |||||
| kern_general_nhwc<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||||
| } else if (imode == ::InterpolationMode::INTER_LINEAR) { | |||||
| kern_general_linear<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| src, map_xy, dst, C, IH, IW, OH, OW, scalar); | src, map_xy, dst, C, IH, IW, OH, OW, scalar); | ||||
| } | } | ||||
| @@ -171,32 +172,45 @@ namespace megdnn { | |||||
| namespace cuda { | namespace cuda { | ||||
| namespace remap { | namespace remap { | ||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| template < | |||||
| typename ctype, const uint32_t format, ::BorderMode bmode, | |||||
| ::InterpolationMode imode> | |||||
| void forward_proxy( | void forward_proxy( | ||||
| const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | ||||
| int OH, int OW, float scalar, cudaStream_t stream) { | int OH, int OW, float scalar, cudaStream_t stream) { | ||||
| dispatch_forward<ctype, format, bmode>( | |||||
| dispatch_forward<ctype, format, bmode, imode>( | |||||
| src, map_xy, dst, N, C, IH, IW, OH, OW, scalar, stream); | src, map_xy, dst, N, C, IH, IW, OH, OW, scalar, stream); | ||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| #define INST(ctype, format, bmode) \ | |||||
| template void \ | |||||
| forward_proxy<ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||||
| #define INST(ctype, format, bmode, imode) \ | |||||
| template void forward_proxy< \ | |||||
| ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode, \ | |||||
| ::InterpolationMode::imode>( \ | |||||
| const ctype*, const float*, ctype*, int, int, int, int, int, int, float, \ | const ctype*, const float*, ctype*, int, int, int, int, int, int, float, \ | ||||
| cudaStream_t); | cudaStream_t); | ||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP) \ | |||||
| INST(ctype, NHWC, BORDER_CONSTANT) \ | |||||
| INST(ctype, NHWC, BORDER_REPLICATE) \ | |||||
| INST(ctype, NHWC, BORDER_REFLECT) \ | |||||
| INST(ctype, NHWC, BORDER_REFLECT_101) \ | |||||
| INST(ctype, NHWC, BORDER_WRAP) | |||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP, INTER_NEAREST) \ | |||||
| INST(ctype, NHWC, BORDER_CONSTANT, INTER_NEAREST) \ | |||||
| INST(ctype, NHWC, BORDER_REPLICATE, INTER_NEAREST) \ | |||||
| INST(ctype, NHWC, BORDER_REFLECT, INTER_NEAREST) \ | |||||
| INST(ctype, NHWC, BORDER_REFLECT_101, INTER_NEAREST) \ | |||||
| INST(ctype, NHWC, BORDER_WRAP, INTER_NEAREST) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101, INTER_LINEAR) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP, INTER_LINEAR) \ | |||||
| INST(ctype, NHWC, BORDER_CONSTANT, INTER_LINEAR) \ | |||||
| INST(ctype, NHWC, BORDER_REPLICATE, INTER_LINEAR) \ | |||||
| INST(ctype, NHWC, BORDER_REFLECT, INTER_LINEAR) \ | |||||
| INST(ctype, NHWC, BORDER_REFLECT_101, INTER_LINEAR) \ | |||||
| INST(ctype, NHWC, BORDER_WRAP, INTER_LINEAR) | |||||
| FOR_FORMAT_BMODE(float) | FOR_FORMAT_BMODE(float) | ||||
| DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | DNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | ||||
| @@ -36,6 +36,12 @@ inline int get_offset<param::Remap::Format::NHWC>( | |||||
| return height * w * c + width * c + channel; | return height * w * c + width * c + channel; | ||||
| } | } | ||||
| template <> | |||||
| inline int get_offset<param::Remap::Format::NHWCD4>( | |||||
| int height, int width, int channel, int, int w, int c) { | |||||
| return ((height * c + channel) * w + width) * 4; | |||||
| } | |||||
| template < | template < | ||||
| typename ctype, param::Remap::Format format, | typename ctype, param::Remap::Format format, | ||||
| param::Remap::BorderMode bordertype> | param::Remap::BorderMode bordertype> | ||||
| @@ -80,8 +86,12 @@ void remap_LINEAR( | |||||
| const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | ||||
| int OH, int OW, float scalar) { | int OH, int OW, float scalar) { | ||||
| RoundingConverter<ctype> round_converter; | RoundingConverter<ctype> round_converter; | ||||
| for (int n = 0; n < N; | |||||
| ++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) { | |||||
| size_t c_scale = 1; | |||||
| if (format == param::Remap::Format::NHWCD4) { | |||||
| c_scale = 4; | |||||
| } | |||||
| for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW, | |||||
| dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) { | |||||
| for (int h = 0; h < OH; ++h) { | for (int h = 0; h < OH; ++h) { | ||||
| for (int w = 0; w < OW; ++w) { | for (int w = 0; w < OW; ++w) { | ||||
| float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | ||||
| @@ -92,18 +102,102 @@ void remap_LINEAR( | |||||
| float u = index_row - row; // alphah | float u = index_row - row; // alphah | ||||
| const float one = 1.f; | const float one = 1.f; | ||||
| for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
| ctype a00 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 0, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a01 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 0, col + 1, c, IH, IW, C, scalar); | |||||
| ctype a10 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 1, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a11 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 1, col + 1, c, IH, IW, C, scalar); | |||||
| dst[get_offset<format>(h, w, c, OH, OW, C)] = round_converter( | |||||
| a00 * (one - v) * (one - u) + a01 * (one - u) * v + | |||||
| a10 * (one - v) * u + a11 * u * v); | |||||
| if (format == param::Remap::Format::NHWCD4) { | |||||
| int idx00 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 0, col + 0, c, IH, IW, C); | |||||
| int idx01 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 0, col + 1, c, IH, IW, C); | |||||
| int idx10 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 1, col + 0, c, IH, IW, C); | |||||
| int idx11 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 1, col + 1, c, IH, IW, C); | |||||
| for (int c_inner = 0; c_inner < 4; ++c_inner) { | |||||
| ctype a00 = (idx00 != -1) ? src[idx00 + c_inner] | |||||
| : round_converter(scalar); | |||||
| ctype a01 = (idx01 != -1) ? src[idx01 + c_inner] | |||||
| : round_converter(scalar); | |||||
| ctype a10 = (idx10 != -1) ? src[idx10 + c_inner] | |||||
| : round_converter(scalar); | |||||
| ctype a11 = (idx11 != -1) ? src[idx11 + c_inner] | |||||
| : round_converter(scalar); | |||||
| dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] = | |||||
| round_converter( | |||||
| a00 * (one - v) * (one - u) + | |||||
| a01 * (one - u) * v + a10 * (one - v) * u + | |||||
| a11 * u * v); | |||||
| } | |||||
| } else { | |||||
| ctype a00 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 0, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a01 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 0, col + 1, c, IH, IW, C, scalar); | |||||
| ctype a10 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 1, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a11 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 1, col + 1, c, IH, IW, C, scalar); | |||||
| dst[get_offset<format>(h, w, c, OH, OW, C)] = round_converter( | |||||
| a00 * (one - v) * (one - u) + a01 * (one - u) * v + | |||||
| a10 * (one - v) * u + a11 * u * v); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
namespace {
// Round f to the nearest integral value; exact .5 ties go to the even
// neighbour. std::round resolves ties away from zero, so a tie that landed on
// an odd value is pulled back by one towards zero.
inline float round_half_to_even(float f) {
    const float rounded = std::round(f);
    const float delta = rounded - f;
    const bool is_tie = (delta == 0.5f) || (delta == -0.5f);
    if (!is_tie || std::fmod(rounded, 2.0f) == 0.0f) {
        return rounded;
    }
    // tie rounded to an odd value: take the even neighbour on the other side
    return f - delta;
}
}  // anonymous namespace
| template < | |||||
| typename ctype, param::Remap::Format format, | |||||
| param::Remap::BorderMode bordertype> | |||||
| void remap_NEAREST( | |||||
| const ctype* src, const float* map_xy, ctype* dst, int N, int C, int IH, int IW, | |||||
| int OH, int OW, float scalar) { | |||||
| RoundingConverter<ctype> round_converter; | |||||
| size_t c_scale = 1; | |||||
| if (format == param::Remap::Format::NHWCD4) { | |||||
| c_scale = 4; | |||||
| } | |||||
| for (int n = 0; n < N; ++n, src += c_scale * C * IH * IW, | |||||
| dst += c_scale * C * OH * OW, map_xy += OH * OW * 2) { | |||||
| for (int h = 0; h < OH; ++h) { | |||||
| for (int w = 0; w < OW; ++w) { | |||||
| float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||||
| int col = static_cast<int>(round_half_to_even(index_col)); | |||||
| int row = static_cast<int>(round_half_to_even(index_row)); | |||||
| for (int c = 0; c < C; ++c) { | |||||
| if (format == param::Remap::Format::NHWCD4) { | |||||
| int idx = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row, col, c, IH, IW, C); | |||||
| for (int c_inner = 0; c_inner < 4; ++c_inner) { | |||||
| dst[get_offset<format>(h, w, c, OH, OW, C) + c_inner] = | |||||
| (idx != -1) ? (src[idx + c_inner]) | |||||
| : round_converter(scalar); | |||||
| } | |||||
| } else { | |||||
| dst[get_offset<format>(h, w, c, OH, OW, C)] = | |||||
| GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row, col, c, IH, IW, C, scalar); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -161,13 +255,40 @@ void remap_LINEAR_backwarddata( | |||||
| } | } | ||||
| } | } | ||||
| template < | |||||
| typename ctype, param::Remap::Format format, | |||||
| param::Remap::BorderMode bordertype> | |||||
| void remap_NEAREST_backwarddata( | |||||
| ctype* grad, const float* map_xy, const ctype* diff, int N, int C, int IH, | |||||
| int IW, int OH, int OW) { | |||||
| std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW); | |||||
| for (int n = 0; n < N; | |||||
| ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) { | |||||
| for (int h = 0; h < OH; ++h) { | |||||
| for (int w = 0; w < OW; ++w) { | |||||
| float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||||
| int col = static_cast<int>(round_half_to_even(index_col)); | |||||
| int row = static_cast<int>(round_half_to_even(index_row)); | |||||
| for (int c = 0; c < C; ++c) { | |||||
| ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)]; | |||||
| int idx = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row, col, c, IH, IW, C); | |||||
| if (idx != -1) { | |||||
| grad[idx] += hidden; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template < | template < | ||||
| typename ctype, param::Remap::Format format, | typename ctype, param::Remap::Format format, | ||||
| param::Remap::BorderMode bordertype> | param::Remap::BorderMode bordertype> | ||||
| void remap_LINEAR_backwardmat( | void remap_LINEAR_backwardmat( | ||||
| const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | const ctype* src, const float* map_xy, const ctype* diff, float* grad, int N, | ||||
| int C, int IH, int IW, int OH, int OW, float scalar) { | int C, int IH, int IW, int OH, int OW, float scalar) { | ||||
| RoundingConverter<ctype> round_converter; | |||||
| std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); | std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); | ||||
| for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW, | for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW, | ||||
| map_xy += OH * OW * 2, grad += OH * OW * 2) { | map_xy += OH * OW * 2, grad += OH * OW * 2) { | ||||
| @@ -194,24 +315,38 @@ void remap_LINEAR_backwardmat( | |||||
| int a11 = GetSrcData<ctype, format, bordertype>::get_index( | int a11 = GetSrcData<ctype, format, bordertype>::get_index( | ||||
| row + 1, col + 1, c, IH, IW, C); | row + 1, col + 1, c, IH, IW, C); | ||||
| dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||||
| dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||||
| dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||||
| dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||||
| du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||||
| du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||||
| du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||||
| du += ((a11 != -1) ? src[a11] : scalar) * v; | |||||
| grad[h * OW * 2 + w * 2 + 0] += round_converter(hidden * dv); | |||||
| grad[h * OW * 2 + w * 2 + 1] += round_converter(hidden * du); | |||||
| dv -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * | |||||
| (one - u); | |||||
| dv += ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * | |||||
| (one - u); | |||||
| dv -= ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * u; | |||||
| dv += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * u; | |||||
| du -= ((a00 != -1) ? static_cast<float>(src[a00]) : scalar) * | |||||
| (one - v); | |||||
| du -= ((a01 != -1) ? static_cast<float>(src[a01]) : scalar) * v; | |||||
| du += ((a10 != -1) ? static_cast<float>(src[a10]) : scalar) * | |||||
| (one - v); | |||||
| du += ((a11 != -1) ? static_cast<float>(src[a11]) : scalar) * v; | |||||
| grad[h * OW * 2 + w * 2 + 0] += hidden * dv; | |||||
| grad[h * OW * 2 + w * 2 + 1] += hidden * du; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template < | |||||
| typename ctype, param::Remap::Format format, | |||||
| param::Remap::BorderMode bordertype> | |||||
| void remap_NEAREST_backwardmat( | |||||
| const ctype*, const float*, const ctype*, float* grad, int N, int, int, int, | |||||
| int OH, int OW, float) { | |||||
| std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); | |||||
| return; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| void RemapImpl::exec( | void RemapImpl::exec( | ||||
| @@ -229,6 +364,11 @@ void RemapImpl::exec( | |||||
| C = src.layout.shape[3]; | C = src.layout.shape[3]; | ||||
| IH = src.layout.shape[1]; | IH = src.layout.shape[1]; | ||||
| IW = src.layout.shape[2]; | IW = src.layout.shape[2]; | ||||
| } else if (param().format == param::Remap::Format::NHWCD4) { | |||||
| N = src.layout.shape[0]; | |||||
| C = src.layout.shape[2]; | |||||
| IH = src.layout.shape[1]; | |||||
| IW = src.layout.shape[3]; | |||||
| } else { | } else { | ||||
| megdnn_throw("unsupported format"); | megdnn_throw("unsupported format"); | ||||
| } | } | ||||
| @@ -255,11 +395,31 @@ void RemapImpl::exec( | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | cb(dt, NCHW, REFLECT, LINEAR); \ | ||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | cb(dt, NCHW, REFLECT_101, LINEAR); \ | ||||
| cb(dt, NCHW, WRAP, LINEAR); \ | cb(dt, NCHW, WRAP, LINEAR); \ | ||||
| cb(dt, NHWCD4, CONSTANT, LINEAR); \ | |||||
| cb(dt, NHWCD4, REPLICATE, LINEAR); \ | |||||
| cb(dt, NHWCD4, REFLECT, LINEAR); \ | |||||
| cb(dt, NHWCD4, REFLECT_101, LINEAR); \ | |||||
| cb(dt, NHWCD4, WRAP, LINEAR); \ | |||||
| cb(dt, NHWC, CONSTANT, LINEAR); \ | cb(dt, NHWC, CONSTANT, LINEAR); \ | ||||
| cb(dt, NHWC, REPLICATE, LINEAR); \ | cb(dt, NHWC, REPLICATE, LINEAR); \ | ||||
| cb(dt, NHWC, REFLECT, LINEAR); \ | cb(dt, NHWC, REFLECT, LINEAR); \ | ||||
| cb(dt, NHWC, REFLECT_101, LINEAR); \ | cb(dt, NHWC, REFLECT_101, LINEAR); \ | ||||
| cb(dt, NHWC, WRAP, LINEAR); \ | cb(dt, NHWC, WRAP, LINEAR); \ | ||||
| cb(dt, NCHW, CONSTANT, NEAREST); \ | |||||
| cb(dt, NCHW, REPLICATE, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NCHW, WRAP, NEAREST); \ | |||||
| cb(dt, NHWCD4, CONSTANT, NEAREST); \ | |||||
| cb(dt, NHWCD4, REPLICATE, NEAREST); \ | |||||
| cb(dt, NHWCD4, REFLECT, NEAREST); \ | |||||
| cb(dt, NHWCD4, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NHWCD4, WRAP, NEAREST); \ | |||||
| cb(dt, NHWC, CONSTANT, NEAREST); \ | |||||
| cb(dt, NHWC, REPLICATE, NEAREST); \ | |||||
| cb(dt, NHWC, REFLECT, NEAREST); \ | |||||
| cb(dt, NHWC, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NHWC, WRAP, NEAREST); \ | |||||
| megdnn_throw( \ | megdnn_throw( \ | ||||
| "format, border type or imode is incorrect in remap navie " \ | "format, border type or imode is incorrect in remap navie " \ | ||||
| "with dtype = " #dt); \ | "with dtype = " #dt); \ | ||||
| @@ -313,6 +473,11 @@ void RemapBackwardDataImpl::exec( | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | cb(dt, NCHW, REFLECT, LINEAR); \ | ||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | cb(dt, NCHW, REFLECT_101, LINEAR); \ | ||||
| cb(dt, NCHW, WRAP, LINEAR); \ | cb(dt, NCHW, WRAP, LINEAR); \ | ||||
| cb(dt, NCHW, CONSTANT, NEAREST); \ | |||||
| cb(dt, NCHW, REPLICATE, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NCHW, WRAP, NEAREST); \ | |||||
| megdnn_throw( \ | megdnn_throw( \ | ||||
| "format, border type or imode is incorrect in remap navie " \ | "format, border type or imode is incorrect in remap navie " \ | ||||
| "with dtype = " #dt); \ | "with dtype = " #dt); \ | ||||
| @@ -365,6 +530,11 @@ void RemapBackwardMatImpl::exec( | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | cb(dt, NCHW, REFLECT, LINEAR); \ | ||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | cb(dt, NCHW, REFLECT_101, LINEAR); \ | ||||
| cb(dt, NCHW, WRAP, LINEAR); \ | cb(dt, NCHW, WRAP, LINEAR); \ | ||||
| cb(dt, NCHW, CONSTANT, NEAREST); \ | |||||
| cb(dt, NCHW, REPLICATE, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT, NEAREST); \ | |||||
| cb(dt, NCHW, REFLECT_101, NEAREST); \ | |||||
| cb(dt, NCHW, WRAP, NEAREST); \ | |||||
| megdnn_throw( \ | megdnn_throw( \ | ||||
| "format, border type or imode is incorrect in remap navie " \ | "format, border type or imode is incorrect in remap navie " \ | ||||
| "with dtype = " #dt); \ | "with dtype = " #dt); \ | ||||
| @@ -34,53 +34,91 @@ static inline std::vector<TestArg> get_nchw_args() { | |||||
| param::Remap param; | param::Remap param; | ||||
| std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NCHW}; | std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NCHW}; | ||||
| std::vector<param::Remap::InterpolationMode> interp_mode_vec = { | |||||
| param::Remap::InterpolationMode::NEAREST, | |||||
| param::Remap::InterpolationMode::LINEAR}; | |||||
| std::vector<param::Remap::BorderMode> border_mode_vec = { | std::vector<param::Remap::BorderMode> border_mode_vec = { | ||||
| param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, | param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, | ||||
| param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, | param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, | ||||
| param::Remap::BorderMode::REPLICATE}; | param::Remap::BorderMode::REPLICATE}; | ||||
| // current do not test this. | // current do not test this. | ||||
| std::vector<float> scalar; | std::vector<float> scalar; | ||||
| for (auto fmt : format_vec) { | for (auto fmt : format_vec) { | ||||
| for (auto border_type : border_mode_vec) { | |||||
| param.format = fmt; | |||||
| param.border_type = border_type; | |||||
| args.emplace_back( | |||||
| param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2}, | |||||
| TensorShape{70000, 1, 2, 2}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 1, 2, 2}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 3, 2, 2}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 10, 100, 100}, TensorShape{1, 100, 100, 2}, | |||||
| TensorShape{1, 10, 100, 100}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2}, | |||||
| TensorShape{2, 4, 100, 200}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 4, 20, 30}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 4, 20, 30}); | |||||
| for (auto interp_mode : interp_mode_vec) { | |||||
| for (auto border_type : border_mode_vec) { | |||||
| param.format = fmt; | |||||
| param.imode = interp_mode; | |||||
| param.border_type = border_type; | |||||
| args.emplace_back( | |||||
| param, TensorShape{70000, 1, 2, 2}, TensorShape{70000, 2, 2, 2}, | |||||
| TensorShape{70000, 1, 2, 2}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 1, 2, 2}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 1, 2, 2}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 3, 2, 2}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 3, 2, 2}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 10, 100, 100}, | |||||
| TensorShape{1, 100, 100, 2}, TensorShape{1, 10, 100, 100}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 100, 200}, TensorShape{2, 100, 200, 2}, | |||||
| TensorShape{2, 4, 100, 200}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 100, 200}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 4, 20, 30}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 10, 10}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 4, 20, 30}); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return args; | return args; | ||||
| } | } | ||||
| static inline std::vector<TestArg> get_nhwcd4_args() { | |||||
| std::vector<TestArg> args; | |||||
| param::Remap param; | |||||
| param.format = param::Remap::Format::NHWCD4; | |||||
| param.imode = param::Remap::InterpolationMode::LINEAR; | |||||
| param.border_type = param::Remap::BorderMode::CONSTANT; | |||||
| // FIXME: when fractional part of bval is not zero, naive and opencl bankend may | |||||
| // have different rounding result | |||||
| param.scalar = 77; | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2}, | |||||
| TensorShape{2, 4, 1, 6, 4}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2}, | |||||
| TensorShape{2, 2, 1, 3, 4}); | |||||
| param.imode = param::Remap::InterpolationMode::NEAREST; | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 6, 2}, | |||||
| TensorShape{2, 4, 1, 6, 4}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 3, 2}, | |||||
| TensorShape{2, 2, 1, 3, 4}); | |||||
| return args; | |||||
| } | |||||
| static inline std::vector<TestArg> get_nhwc_args() { | static inline std::vector<TestArg> get_nhwc_args() { | ||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| param::Remap param; | param::Remap param; | ||||
| std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NHWC}; | std::vector<param::Remap::Format> format_vec = {param::Remap::Format::NHWC}; | ||||
| std::vector<param::Remap::InterpolationMode> interp_mode_vec = { | |||||
| param::Remap::InterpolationMode::NEAREST, | |||||
| param::Remap::InterpolationMode::LINEAR}; | |||||
| std::vector<param::Remap::BorderMode> border_mode_vec = { | std::vector<param::Remap::BorderMode> border_mode_vec = { | ||||
| param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, | param::Remap::BorderMode::CONSTANT, param::Remap::BorderMode::REFLECT_101, | ||||
| param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, | param::Remap::BorderMode::REFLECT, param::Remap::BorderMode::WRAP, | ||||
| @@ -88,41 +126,44 @@ static inline std::vector<TestArg> get_nhwc_args() { | |||||
| // current do not test this. | // current do not test this. | ||||
| std::vector<float> scalar; | std::vector<float> scalar; | ||||
| for (auto fmt : format_vec) { | for (auto fmt : format_vec) { | ||||
| for (auto border_type : border_mode_vec) { | |||||
| param.format = fmt; | |||||
| param.border_type = border_type; | |||||
| param.scalar = 12.f; | |||||
| args.emplace_back( | |||||
| param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2}, | |||||
| TensorShape{70000, 2, 2, 1}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 2, 2, 1}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 2, 2, 3}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 2, 2, 66}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 100, 100, 10}, TensorShape{1, 100, 100, 2}, | |||||
| TensorShape{1, 100, 100, 10}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2}, | |||||
| TensorShape{2, 100, 200, 4}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 20, 30, 4}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 20, 30, 4}); | |||||
| for (auto interp_mode : interp_mode_vec) { | |||||
| for (auto border_type : border_mode_vec) { | |||||
| param.format = fmt; | |||||
| param.imode = interp_mode; | |||||
| param.border_type = border_type; | |||||
| param.scalar = 12.f; | |||||
| args.emplace_back( | |||||
| param, TensorShape{70000, 2, 2, 1}, TensorShape{70000, 2, 2, 2}, | |||||
| TensorShape{70000, 2, 2, 1}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 2, 2, 1}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 2, 2, 1}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 2, 2, 3}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 2, 2, 3}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 2, 2, 66}, TensorShape{1, 2, 2, 2}, | |||||
| TensorShape{1, 2, 2, 66}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{1, 100, 100, 10}, | |||||
| TensorShape{1, 100, 100, 2}, TensorShape{1, 100, 100, 10}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 100, 200, 4}, TensorShape{2, 100, 200, 2}, | |||||
| TensorShape{2, 100, 200, 4}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 100, 200, 4}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 20, 30, 4}); | |||||
| args.emplace_back( | |||||
| param, TensorShape{2, 10, 10, 4}, TensorShape{2, 20, 30, 2}, | |||||
| TensorShape{2, 20, 30, 4}); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return args; | return args; | ||||
| @@ -58,6 +58,11 @@ static void set_nchw_args(std::vector<TestArg>& args) { | |||||
| args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); | args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); | ||||
| args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3}); | args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3}); | ||||
| args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); | args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); | ||||
| param.imode = param::Resize::InterpolationMode::NEAREST; | |||||
| args.emplace_back(param, TensorShape{2, 2, 3, 4}, TensorShape{2, 2, 6, 8}); | |||||
| args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 4, 3}); | |||||
| args.emplace_back(param, TensorShape{1, 2, 6, 8}, TensorShape{1, 2, 3, 4}); | |||||
| } | } | ||||
| static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) { | static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) { | ||||
| @@ -75,6 +80,25 @@ static inline std::vector<TestArg> get_args(IMode imode = IMode::INTER_LINEAR) { | |||||
| return args; | return args; | ||||
| } | } | ||||
| static inline std::vector<TestArg> get_nhwc_args() { | |||||
| std::vector<TestArg> args; | |||||
| param::Resize param; | |||||
| param.format = param::Resize::Format::NHWC; | |||||
| param.imode = param::Resize::InterpolationMode::LINEAR; | |||||
| args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2}); | |||||
| args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); | |||||
| args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2}); | |||||
| param.imode = param::Resize::InterpolationMode::NEAREST; | |||||
| args.emplace_back(param, TensorShape{2, 3, 4, 2}, TensorShape{2, 6, 8, 2}); | |||||
| args.emplace_back(param, TensorShape{1, 2, 2, 2}, TensorShape{1, 4, 3, 2}); | |||||
| args.emplace_back(param, TensorShape{1, 6, 8, 2}, TensorShape{1, 3, 4, 2}); | |||||
| return args; | |||||
| } | |||||
| static inline std::vector<TestArg> get_nhwcd4_args() { | static inline std::vector<TestArg> get_nhwcd4_args() { | ||||
| std::vector<TestArg> args; | std::vector<TestArg> args; | ||||
| @@ -83,6 +107,9 @@ static inline std::vector<TestArg> get_nhwcd4_args() { | |||||
| param.imode = param::Resize::InterpolationMode::LINEAR; | param.imode = param::Resize::InterpolationMode::LINEAR; | ||||
| args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4}); | args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4}); | ||||
| args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4}); | args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4}); | ||||
| param.imode = param::Resize::InterpolationMode::NEAREST; | |||||
| args.emplace_back(param, TensorShape{2, 2, 1, 3, 4}, TensorShape{2, 4, 1, 6, 4}); | |||||
| args.emplace_back(param, TensorShape{2, 4, 1, 6, 4}, TensorShape{2, 2, 1, 3, 4}); | |||||
| return args; | return args; | ||||
| } | } | ||||
| @@ -351,7 +351,7 @@ def remap( | |||||
| "reflect_101", "wrap". | "reflect_101", "wrap". | ||||
| scalar: value used in case of a constant border. Default: 0 | scalar: value used in case of a constant border. Default: 0 | ||||
| interp_mode: interpolation methods. | interp_mode: interpolation methods. | ||||
| Default: "linear". Currently only support "linear" mode. | |||||
| Default: "linear". "nearest" mode is also supported. | |||||
| Returns: | Returns: | ||||
| output tensor. | output tensor. | ||||