| @@ -270,6 +270,41 @@ protected: | |||
| }; | |||
| using Remap = RemapForward; | |||
| class RemapBackwardData : public RemapBase { | |||
| DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||
| }; | |||
| class RemapBackwardMat : public RemapBase { | |||
| DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| class SeparableFilterBase : public OperatorBase { | |||
| DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase); | |||
| DEF_OPR_PARAM(SeparableFilter); | |||
| @@ -197,6 +197,8 @@ private: | |||
| cb(ROIAlignBackward) \ | |||
| cb(BatchConvBiasForward) \ | |||
| cb(Remap) \ | |||
| cb(RemapBackwardData) \ | |||
| cb(RemapBackwardMat) \ | |||
| /*! | |||
| * \brief specialize HandleImpl::create_operator for a single opr type; | |||
| @@ -50,6 +50,7 @@ void RemapBase::check_layout_fwd(const TensorLayout& src, | |||
| megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); | |||
| megdnn_assert(map_xy.shape[3] == 2); | |||
| megdnn_assert(map_xy.shape[0] == src.shape[0]); | |||
| megdnn_assert_contiguous(src); | |||
| // map_xy only support floa32 type | |||
| // map_xy always in NHWC format | |||
| @@ -85,6 +86,34 @@ void Remap::check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| void RemapBackwardData::check_exec(const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_in_bytes) { | |||
| check_layout_fwd(grad, map_xy, diff); | |||
| megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16( | |||
| || grad.dtype == dtype::BFloat16()), | |||
| "Backward Remap only supports Float32/BFloat16."); | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(map_xy, diff, grad); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| void RemapBackwardMat::check_exec(const TensorLayout& src, | |||
| const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_in_bytes) { | |||
| check_layout_fwd(src, map_xy, diff); | |||
| megdnn_assert_eq_layout(map_xy, grad); | |||
| megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16( | |||
| || grad.dtype == dtype::BFloat16()), | |||
| "Backward Remap only supports Float32/BFloat16."); | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(src, map_xy, diff, grad); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * \file dnn/src/cuda/remap/backward_data.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/remap/common.h" | |||
| #include "src/cuda/remap/opr_impl.h" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); | |||
| megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR, | |||
| "only support LINEAR interpolationMode"); | |||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||
| "only support NCHW format for remap backward"); | |||
| auto stream = cuda_stream(this->handle()); | |||
| int N, C, IH, IW, OH, OW; | |||
| N = grad.layout.shape[0]; | |||
| C = grad.layout.shape[1]; | |||
| IH = grad.layout.shape[2]; | |||
| IW = grad.layout.shape[3]; | |||
| OH = map_xy.layout.shape[1]; | |||
| OW = map_xy.layout.shape[2]; | |||
| #define cb(dt, _format, bmode) \ | |||
| if (param().format == param::Remap::Format::_format && \ | |||
| param().border_type == param::Remap::BorderMode::bmode) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| remap::backwarddata_proxy<ctype, param_enumv::Remap::Format::_format, \ | |||
| ::BorderMode::BORDER_##bmode>( \ | |||
| grad.compatible_ptr<ctype>(), \ | |||
| map_xy.compatible_ptr<dt_float32>(), \ | |||
| diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \ | |||
| break; \ | |||
| } | |||
| #define support_dtype(dt) \ | |||
| case DTypeTrait<dt>::enumv: { \ | |||
| cb(dt, NCHW, CONSTANT); \ | |||
| cb(dt, NCHW, REPLICATE); \ | |||
| cb(dt, NCHW, REFLECT); \ | |||
| cb(dt, NCHW, REFLECT_101); \ | |||
| cb(dt, NCHW, WRAP); \ | |||
| megdnn_throw("unsupported border type in remap cuda"); \ | |||
| } | |||
| switch (grad.layout.dtype.enumv()) { | |||
| support_dtype(dtype::Float32); | |||
| support_dtype(dtype::BFloat16); | |||
| default: | |||
| megdnn_throw("unsupported dtype in remap backward cuda\n"); | |||
| } | |||
| #undef support_dtype | |||
| #undef cb | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,169 @@ | |||
| /** | |||
| * \file dnn/src/cuda/remap/backward_data.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include <cuda_runtime.h> | |||
| #include "src/common/rounding_converter.cuh" | |||
| #include "src/cuda/cv/kernel_common.cuh" | |||
| #include "src/cuda/remap/common.h" | |||
| #include "src/cuda/utils.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace remap; | |||
| using namespace rounding; | |||
| namespace { | |||
| template <const uint32_t format> | |||
| __device__ inline int get_offset(int height, int width, int channel, int h, | |||
| int w, int c); | |||
| template <> | |||
| __device__ inline int get_offset<param_enumv::Remap::Format::NCHW>( | |||
| int height, int width, int channel, int h, int w, int c) { | |||
| return channel * h * w + height * w + width; | |||
| } | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| struct GetSrcData { | |||
| __device__ static inline int get_index(int height, int width, int channel, | |||
| int h, int w, int c) { | |||
| height = megcv::border_interpolate<bmode>(height, h); | |||
| width = megcv::border_interpolate<bmode>(width, w); | |||
| return get_offset<format>(height, width, channel, h, w, c); | |||
| } | |||
| }; | |||
| template <typename ctype, const uint32_t format> | |||
| struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||
| __device__ static inline int get_index(int height, int width, int channel, | |||
| int h, int w, int c) { | |||
| return (height >= 0 && height < h && width >= 0 && width < w) | |||
| ? get_offset<format>(height, width, channel, h, w, c) | |||
| : -1; | |||
| } | |||
| }; | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| __global__ void kern_general(ctype* __restrict grad, const float* map_xy, | |||
| const ctype* diff, int C, int IH, int IW, int OH, | |||
| int OW) { | |||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
| grad += blockIdx.z * C * IH * IW; | |||
| diff += blockIdx.z * C * OH * OW; | |||
| map_xy += blockIdx.z * 2 * OH * OW; | |||
| RoundingConverter<ctype> round_converter; | |||
| if (ow < OW && oh < OH) { | |||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
| int col = static_cast<int>(floor(index_col)); | |||
| int row = static_cast<int>(floor(index_row)); | |||
| float v = index_col - col; // alphah | |||
| float u = index_row - row; // alphaw | |||
| const float one = 1.f; | |||
| for (int c = 0; c < C; ++c) { | |||
| float hidden = static_cast<float>( | |||
| diff[get_offset<format>(oh, ow, c, OH, OW, C)]); | |||
| int a00 = GetSrcData<ctype, format, bmode>::get_index( | |||
| row + 0, col + 0, c, IH, IW, C); | |||
| if (a00 != -1) { | |||
| atomic_add(grad + a00, | |||
| round_converter((one - u) * (one - v) * hidden)); | |||
| } | |||
| int a01 = GetSrcData<ctype, format, bmode>::get_index( | |||
| row + 0, col + 1, c, IH, IW, C); | |||
| if (a01 != -1) { | |||
| atomic_add(grad + a01, round_converter((one - u) * v * hidden)); | |||
| } | |||
| int a10 = GetSrcData<ctype, format, bmode>::get_index( | |||
| row + 1, col + 0, c, IH, IW, C); | |||
| if (a10 != -1) { | |||
| atomic_add(grad + a10, round_converter(u * (one - v) * hidden)); | |||
| } | |||
| int a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | |||
| bmode>::get_index(row + 1, col + 1, c, IH, IW, | |||
| C); | |||
| if (a11 != -1) { | |||
| atomic_add(grad + a11, round_converter(u * v * hidden)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void dispatch_backwarddata(ctype* grad, const float* map_xy, const ctype* diff, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| cudaStream_t stream) { | |||
| const int BX = 32, BY = 16; | |||
| const int max_batch_size = 65535; | |||
| while (N) { | |||
| size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; | |||
| dim3 threads(BX, BY); | |||
| dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||
| cuda_check(cudaMemsetAsync( | |||
| grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, | |||
| stream)); | |||
| kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
| grad, map_xy, diff, C, IH, IW, OH, OW); | |||
| N -= curr_batch_size; | |||
| grad += curr_batch_size * C * IH * IW; | |||
| diff += curr_batch_size * C * OH * OW; | |||
| map_xy += curr_batch_size * 2 * OH * OW; | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace remap { | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| cudaStream_t stream) { | |||
| dispatch_backwarddata<ctype, format, bmode>(grad, map_xy, diff, N, C, IH, | |||
| IW, OH, OW, stream); | |||
| after_kernel_launch(); | |||
| } | |||
| #define INST(ctype, format, bmode) \ | |||
| template void backwarddata_proxy< \ | |||
| ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||
| ctype*, const float*, const ctype*, int, int, int, int, int, int, \ | |||
| cudaStream_t); | |||
| #define FOR_FORMAT_BMODE(ctype) \ | |||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||
| INST(ctype, NCHW, BORDER_REPLICATE) \ | |||
| INST(ctype, NCHW, BORDER_REFLECT) \ | |||
| INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||
| INST(ctype, NCHW, BORDER_WRAP) | |||
| FOR_FORMAT_BMODE(float) | |||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||
| #undef FOR_FORMAT_BMODE | |||
| #undef INST | |||
| } // namespace remap | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,73 @@ | |||
| /** | |||
| * \file dnn/src/cuda/remap/backward_mat.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/cuda/remap/common.h" | |||
| #include "src/cuda/remap/opr_impl.h" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, | |||
| workspace.size); | |||
| megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR, | |||
| "only support LINEAR interpolationMode"); | |||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||
| "only support NCHW format for remap backward"); | |||
| auto stream = cuda_stream(this->handle()); | |||
| int N, C, IH, IW, OH, OW; | |||
| N = src.layout.shape[0]; | |||
| C = src.layout.shape[1]; | |||
| IH = src.layout.shape[2]; | |||
| IW = src.layout.shape[3]; | |||
| OH = map_xy.layout.shape[1]; | |||
| OW = map_xy.layout.shape[2]; | |||
| #define cb(dt, _format, bmode) \ | |||
| if (param().format == param::Remap::Format::_format && \ | |||
| param().border_type == param::Remap::BorderMode::bmode) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| remap::backwardmat_proxy<ctype, param_enumv::Remap::Format::_format, \ | |||
| ::BorderMode::BORDER_##bmode>( \ | |||
| src.compatible_ptr<ctype>(), \ | |||
| map_xy.compatible_ptr<dt_float32>(), \ | |||
| diff.compatible_ptr<ctype>(), \ | |||
| grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \ | |||
| param().scalar, stream); \ | |||
| break; \ | |||
| } | |||
| #define support_dtype(dt) \ | |||
| case DTypeTrait<dt>::enumv: { \ | |||
| cb(dt, NCHW, CONSTANT); \ | |||
| cb(dt, NCHW, REPLICATE); \ | |||
| cb(dt, NCHW, REFLECT); \ | |||
| cb(dt, NCHW, REFLECT_101); \ | |||
| cb(dt, NCHW, WRAP); \ | |||
| megdnn_throw("unsupported border type in remap cuda"); \ | |||
| } | |||
| switch (src.layout.dtype.enumv()) { | |||
| support_dtype(dtype::Float32); | |||
| support_dtype(dtype::BFloat16); | |||
| default: | |||
| megdnn_throw("unsupported dtype in remap backward cuda\n"); | |||
| } | |||
| #undef support_dtype | |||
| #undef cb | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,170 @@ | |||
| /** | |||
| * \file dnn/src/cuda/remap/backward_mat.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include <cuda_runtime.h> | |||
| #include "src/common/rounding_converter.cuh" | |||
| #include "src/cuda/cv/kernel_common.cuh" | |||
| #include "src/cuda/remap/common.h" | |||
| #include "src/cuda/utils.cuh" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace remap; | |||
| using namespace rounding; | |||
| namespace { | |||
| template <const uint32_t format> | |||
| __device__ inline int get_offset(int height, int width, int channel, int h, | |||
| int w, int c); | |||
| template <> | |||
| __device__ inline int get_offset<param_enumv::Remap::Format::NCHW>( | |||
| int height, int width, int channel, int h, int w, int c) { | |||
| return channel * h * w + height * w + width; | |||
| } | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| struct GetSrcData { | |||
| __device__ static inline int get_index(int height, int width, int channel, | |||
| int h, int w, int c) { | |||
| height = megcv::border_interpolate<bmode>(height, h); | |||
| width = megcv::border_interpolate<bmode>(width, w); | |||
| return get_offset<format>(height, width, channel, h, w, c); | |||
| } | |||
| }; | |||
| template <typename ctype, const uint32_t format> | |||
| struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||
| __device__ static inline int get_index(int height, int width, int channel, | |||
| int h, int w, int c) { | |||
| return (height >= 0 && height < h && width >= 0 && width < w) | |||
| ? get_offset<format>(height, width, channel, h, w, c) | |||
| : -1; | |||
| } | |||
| }; | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| __global__ void kern_general(const ctype* src, const float* map_xy, | |||
| const ctype* diff, float* __restrict grad, int C, | |||
| int IH, int IW, int OH, int OW, float scalar) { | |||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
| src += blockIdx.z * C * IH * IW; | |||
| diff += blockIdx.z * C * OH * OW; | |||
| map_xy += blockIdx.z * 2 * OH * OW; | |||
| grad += blockIdx.z * 2 * OH * OW; | |||
| RoundingConverter<ctype> round_converter; | |||
| if (ow < OW && oh < OH) { | |||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
| int col = static_cast<int>(floor(index_col)); | |||
| int row = static_cast<int>(floor(index_row)); | |||
| float v = index_col - col; // alphaw | |||
| float u = index_row - row; // alphah | |||
| const float one = 1.f; | |||
| for (int c = 0; c < C; ++c) { | |||
| float hidden = static_cast<float>( | |||
| diff[get_offset<format>( | |||
| oh, ow, c, OH, OW, C)]); | |||
| float du = 0.f, dv = 0.f; | |||
| int a00 = GetSrcData<ctype, format, bmode>::get_index( | |||
| row + 0, col + 0, c, IH, IW, C); | |||
| int a01 = GetSrcData<ctype, format, bmode>::get_index( | |||
| row + 0, col + 1, c, IH, IW, C); | |||
| int a10 = GetSrcData<ctype, format, bmode>::get_index( | |||
| row + 1, col + 0, c, IH, IW, C); | |||
| int a11 = GetSrcData<ctype, format, bmode>::get_index( | |||
| row + 1, col + 1, c, IH, IW, C); | |||
| dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||
| dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||
| dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||
| dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||
| du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||
| du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||
| du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||
| du += ((a11 != -1) ? src[a11] : scalar) * v; | |||
| grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv); | |||
| grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du); | |||
| } | |||
| } | |||
| } | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void dispatch_backwardmat(const ctype* src, const float* map_xy, | |||
| const ctype* diff, float* grad, int N, int C, int IH, | |||
| int IW, int OH, int OW, float scalar, | |||
| cudaStream_t stream) { | |||
| const int BX = 32, BY = 16; | |||
| const int max_batch_size = 65535; | |||
| while (N) { | |||
| size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; | |||
| dim3 threads(BX, BY); | |||
| dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||
| cuda_check(cudaMemsetAsync( | |||
| grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, | |||
| stream)); | |||
| kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||
| src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); | |||
| N -= curr_batch_size; | |||
| src += curr_batch_size * C * IH * IW; | |||
| diff += curr_batch_size * C * OH * OW; | |||
| map_xy += curr_batch_size * 2 * OH * OW; | |||
| grad += curr_batch_size * 2 * OH * OW; | |||
| } | |||
| } | |||
| } // anonymous namespace | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| namespace remap { | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff, | |||
| float* grad, int N, int C, int IH, int IW, int OH, | |||
| int OW, float scalar, cudaStream_t stream) { | |||
| dispatch_backwardmat<ctype, format, bmode>(src, map_xy, diff, grad, N, C, | |||
| IH, IW, OH, OW, scalar, stream); | |||
| after_kernel_launch(); | |||
| } | |||
| #define INST(ctype, format, bmode) \ | |||
| template void backwardmat_proxy<ctype, param_enumv::Remap::Format::format, \ | |||
| ::BorderMode::bmode>( \ | |||
| const ctype*, const float*, const ctype*, float*, int, int, int, \ | |||
| int, int, int, float, cudaStream_t); | |||
| #define FOR_FORMAT_BMODE(ctype) \ | |||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||
| INST(ctype, NCHW, BORDER_REPLICATE) \ | |||
| INST(ctype, NCHW, BORDER_REFLECT) \ | |||
| INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||
| INST(ctype, NCHW, BORDER_WRAP) | |||
| FOR_FORMAT_BMODE(float) | |||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||
| #undef FOR_FORMAT_BMODE | |||
| #undef INST | |||
| } // namespace remap | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -24,7 +24,17 @@ namespace remap { | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | |||
| int C, int IH, int IW, int OH, int OW, float scalar, | |||
| int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream); | |||
| cudaStream_t stream); | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| cudaStream_t stream); | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff, | |||
| float* grad, int N, int C, int IH, int IW, int OH, | |||
| int OW, float scalar, cudaStream_t stream); | |||
| } // namespace remap | |||
| } // namespace cuda | |||
| @@ -22,9 +22,10 @@ using namespace cuda; | |||
| void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||
| _megdnn_tensor_in dst, _megdnn_workspace workspace) { | |||
| check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); | |||
| megdnn_assert(map_xy.layout.dtype.enumv() == | |||
| DTypeTrait<dtype::Float32>::enumv); | |||
| auto stream = cuda_stream(this->handle()); | |||
| int N, C, IH, IW, OH, OW; | |||
| ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0; | |||
| OH = map_xy.layout.shape[1]; | |||
| OW = map_xy.layout.shape[2]; | |||
| @@ -36,10 +37,6 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||
| C = src.layout.shape[1]; | |||
| IH = src.layout.shape[2]; | |||
| IW = src.layout.shape[3]; | |||
| S_IN = src.layout.stride[0]; | |||
| S_IC = src.layout.stride[1]; | |||
| S_IH = src.layout.stride[2]; | |||
| S_IW = src.layout.stride[3]; | |||
| } else if (param().format == param::Remap::Format::NHWC) { | |||
| N = src.layout.shape[0]; | |||
| C = src.layout.shape[3]; | |||
| @@ -58,7 +55,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||
| src.compatible_ptr<ctype>(), \ | |||
| map_xy.compatible_ptr<dt_float32>(), \ | |||
| dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | |||
| param().scalar, S_IN, S_IC, S_IH, S_IW, stream); \ | |||
| param().scalar, stream); \ | |||
| break; \ | |||
| } | |||
| @@ -78,15 +75,16 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||
| } | |||
| switch (src.layout.dtype.enumv()) { | |||
| support_dtype(dtype::Float32) | |||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)) | |||
| support_dtype(dtype::Int8) | |||
| support_dtype(dtype::Uint8) | |||
| support_dtype(dtype::Float32); | |||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)); | |||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||
| support_dtype(dtype::Int8); | |||
| support_dtype(dtype::Uint8); | |||
| default: | |||
| megdnn_throw("unsupported dtype in remap cuda"); | |||
| } | |||
| #undef supported_dtype | |||
| #undef support_dtype | |||
| #undef cb | |||
| } | |||
| @@ -23,17 +23,6 @@ using namespace rounding; | |||
| namespace { | |||
| template <typename ctype> | |||
| struct DirectSrcVisitor { | |||
| const ctype* ptr; | |||
| __device__ __forceinline__ const ctype* get(int batch, int im_size) { | |||
| return ptr + batch * im_size; | |||
| } | |||
| void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; } | |||
| }; | |||
| template <const uint32_t format> | |||
| __device__ inline int get_offset(int height, int width, int channel, int h, | |||
| int w, int c); | |||
| @@ -74,14 +63,13 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||
| } | |||
| }; | |||
| template <typename ctype, typename SrcVisitor, ::BorderMode bmode> | |||
| __global__ void kern_general(SrcVisitor src, const float* map_xy, | |||
| template <typename ctype, ::BorderMode bmode> | |||
| __global__ void kern_general(const ctype* __restrict sptr, const float* map_xy, | |||
| ctype* __restrict dst, int C, int IH, int IW, | |||
| int OH, int OW, int S_IN, int S_IC, int S_IH, | |||
| int S_IW, float scalar) { | |||
| int OH, int OW, float scalar) { | |||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
| const ctype* __restrict sptr = src.get(blockIdx.z, S_IN); | |||
| sptr += blockIdx.z * C * IH * IW; | |||
| dst += blockIdx.z * C * OH * OW; | |||
| map_xy += blockIdx.z * 2 * OH * OW; | |||
| RoundingConverter<ctype> round_converter; | |||
| @@ -89,8 +77,8 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy, | |||
| if (ow < OW && oh < OH) { | |||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
| int col = (int)floor(index_col); | |||
| int row = (int)floor(index_row); | |||
| int col = static_cast<int>(floor(index_col)); | |||
| int row = static_cast<int>(floor(index_row)); | |||
| float v = index_col - col; | |||
| float u = index_row - row; | |||
| for (int c = 0; c < C; ++c) { | |||
| @@ -106,22 +94,25 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy, | |||
| ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | |||
| bmode>::get(sptr, row + 1, col + 1, c, IH, | |||
| IW, C, scalar); | |||
| dst[get_offset<param_enumv::Remap::Format::NCHW>(oh, ow, c, OH, OW, | |||
| C)] = | |||
| round_converter(a00 * (1.f - u) * (1.f - v) + | |||
| a01 * (1.f - u) * v + a10 * (1.f - v) * u + | |||
| a11 * u * v); | |||
| /* in remap, we use float as the type of intermediate result */ | |||
| float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | |||
| static_cast<float>(a01) * (1.f - u) * v + | |||
| static_cast<float>(a10) * (1.f - v) * u + | |||
| static_cast<float>(a11) * u * v; | |||
| dst[get_offset<param_enumv::Remap::Format::NCHW>( | |||
| oh, ow, c, OH, OW, C)] = round_converter(result); | |||
| } | |||
| } | |||
| } | |||
| template <typename ctype, typename SrcVisitor, ::BorderMode bmode> | |||
| __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy, | |||
| ctype* __restrict dst, int C, int IH, int IW, | |||
| int OH, int OW, float scalar) { | |||
| template <typename ctype, ::BorderMode bmode> | |||
| __global__ void kern_general_nhwc(const ctype* __restrict sptr, | |||
| const float* map_xy, ctype* __restrict dst, | |||
| int C, int IH, int IW, int OH, int OW, | |||
| float scalar) { | |||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||
| const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); | |||
| sptr += blockIdx.z * C * IH * IW; | |||
| dst += blockIdx.z * C * OH * OW; | |||
| map_xy += blockIdx.z * 2 * OH * OW; | |||
| RoundingConverter<ctype> round_converter; | |||
| @@ -129,8 +120,8 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy, | |||
| if (ow < OW && oh < OH) { | |||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||
| int col = (int)floor(index_col); | |||
| int row = (int)floor(index_row); | |||
| int col = static_cast<int>(floor(index_col)); | |||
| int row = static_cast<int>(floor(index_row)); | |||
| float v = index_col - col; | |||
| float u = index_row - row; | |||
| for (int c = 0; c < C; ++c) { | |||
| @@ -146,21 +137,21 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy, | |||
| ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, | |||
| bmode>::get(sptr, row + 1, col + 1, c, IH, | |||
| IW, C, scalar); | |||
| dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH, OW, | |||
| C)] = | |||
| round_converter(a00 * (1.f - u) * (1.f - v) + | |||
| a01 * (1.f - u) * v + a10 * (1.f - v) * u + | |||
| a11 * u * v); | |||
| /* in remap, we use float as the type of intermediate result */ | |||
| float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | |||
| static_cast<float>(a01) * (1.f - u) * v + | |||
| static_cast<float>(a10) * (1.f - v) * u + | |||
| static_cast<float>(a11) * u * v; | |||
| dst[get_offset<param_enumv::Remap::Format::NHWC>( | |||
| oh, ow, c, OH, OW, C)] = round_converter(result); | |||
| } | |||
| } | |||
| } | |||
| template <typename ctype, typename SrcVisitor, const uint32_t format, | |||
| ::BorderMode bmode> | |||
| void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst, | |||
| int N, int C, int IH, int IW, int OH, int OW, | |||
| float scalar, int S_IN, int S_IC, int S_IH, int S_IW, | |||
| cudaStream_t stream) { | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void dispatch_forward(const ctype* src, const float* map_xy, ctype* dst, int N, | |||
| int C, int IH, int IW, int OH, int OW, float scalar, | |||
| cudaStream_t stream) { | |||
| const int BX = 32, BY = 16; | |||
| const int max_batch_size = 65535; | |||
| @@ -170,19 +161,17 @@ void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst, | |||
| dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||
| if (format == param_enumv::Remap::Format::NCHW) { | |||
| kern_general<ctype, SrcVisitor, bmode> | |||
| <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH, | |||
| IW, OH, OW, S_IN, S_IC, | |||
| S_IH, S_IW, scalar); | |||
| kern_general<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||
| src, map_xy, dst, C, IH, IW, OH, OW, scalar); | |||
| } else if (format == param_enumv::Remap::Format::NHWC) { | |||
| kern_general_nhwc<ctype, SrcVisitor, bmode> | |||
| <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH, | |||
| IW, OH, OW, scalar); | |||
| kern_general_nhwc<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||
| src, map_xy, dst, C, IH, IW, OH, OW, scalar); | |||
| } | |||
| N -= curr_batch_size; | |||
| src.move_batch(curr_batch_size, C * IH * IW); | |||
| src += curr_batch_size * C * IH * IW; | |||
| dst += curr_batch_size * C * OH * OW; | |||
| map_xy += curr_batch_size * OH * OW * 2; | |||
| } | |||
| } | |||
| @@ -195,22 +184,17 @@ namespace remap { | |||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||
| void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | |||
| int C, int IH, int IW, int OH, int OW, float scalar, | |||
| int S_IN, int S_IC, int S_IH, int S_IW, | |||
| cudaStream_t stream) { | |||
| DirectSrcVisitor<ctype> visitor; | |||
| visitor.ptr = src; | |||
| using SrcVisitor = DirectSrcVisitor<ctype>; | |||
| dispatch_with_visitor<ctype, SrcVisitor, format, bmode>( | |||
| visitor, map_xy, dst, N, C, IH, IW, OH, OW, scalar, S_IN, S_IC, | |||
| S_IH, S_IW, stream); | |||
| dispatch_forward<ctype, format, bmode>(src, map_xy, dst, N, C, IH, IW, OH, | |||
| OW, scalar, stream); | |||
| after_kernel_launch(); | |||
| } | |||
| #define INST(ctype, format, bmode) \ | |||
| template void forward_proxy<ctype, param_enumv::Remap::Format::format, \ | |||
| ::BorderMode::bmode>( \ | |||
| const ctype* src, const float*, ctype*, int, int, int, int, int, \ | |||
| int, float, int, int, int, int, cudaStream_t); | |||
| #define INST(ctype, format, bmode) \ | |||
| template void forward_proxy<ctype, param_enumv::Remap::Format::format, \ | |||
| ::BorderMode::bmode>( \ | |||
| const ctype*, const float*, ctype*, int, int, int, int, int, int, \ | |||
| float, cudaStream_t); | |||
| #define FOR_FORMAT_BMODE(ctype) \ | |||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||
| @@ -226,11 +210,13 @@ void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | |||
| FOR_FORMAT_BMODE(float) | |||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | |||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||
| FOR_FORMAT_BMODE(int8_t) | |||
| FOR_FORMAT_BMODE(uint8_t) | |||
| #undef FOR_BMODE | |||
| #undef FOR_FORMAT_BMODE | |||
| #undef INST | |||
| } // namespace remap | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -15,13 +15,41 @@ | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| class RemapImpl final : public Remap { | |||
| public: | |||
| using Remap::Remap; | |||
| void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out, | |||
| _megdnn_workspace) override; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& map_xy, | |||
| const TensorLayout& dst) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class RemapBackwardDataImpl final : public RemapBackwardData { | |||
| public: | |||
| using RemapBackwardData::RemapBackwardData; | |||
| void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class RemapBackwardMatImpl final : public RemapBackwardMat { | |||
| public: | |||
| using RemapBackwardMat::RemapBackwardMat; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| @@ -12,11 +12,13 @@ | |||
| #include "src/naive/remap/opr_impl.h" | |||
| #include "src/common/cv/helper.h" | |||
| #include "src/common/rounding_converter.cuh" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| using namespace megdnn; | |||
| using namespace naive; | |||
| using namespace rounding; | |||
| namespace { | |||
| template <param::Remap::Format format> | |||
| @@ -36,35 +38,46 @@ inline int get_offset<param::Remap::Format::NHWC>(int height, int width, | |||
| return height * w * c + width * c + channel; | |||
| } | |||
| template <typename DataType, param::Remap::Format format, | |||
| template <typename ctype, param::Remap::Format format, | |||
| param::Remap::BorderMode bordertype> | |||
| struct GetSrcData { | |||
| static inline DataType get(const DataType* src, int height, int width, | |||
| int channel, int h, int w, int c, float, | |||
| std::function<DataType(float)>) { | |||
| static inline ctype get(const ctype* src, int height, int width, | |||
| int channel, int h, int w, int c, float) { | |||
| height = megcv::border_interpolate<bordertype>(height, h); | |||
| width = megcv::border_interpolate<bordertype>(width, w); | |||
| return src[get_offset<format>(height, width, channel, h, w, c)]; | |||
| } | |||
| static inline int get_index(int height, int width, int channel, int h, | |||
| int w, int c) { | |||
| height = megcv::border_interpolate<bordertype>(height, h); | |||
| width = megcv::border_interpolate<bordertype>(width, w); | |||
| return get_offset<format>(height, width, channel, h, w, c); | |||
| } | |||
| }; | |||
| template <typename DataType, param::Remap::Format format> | |||
| struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> { | |||
| static inline DataType get(const DataType* src, int height, int width, | |||
| int channel, int h, int w, int c, float scalar, | |||
| std::function<DataType(float)> round) { | |||
| template <typename ctype, param::Remap::Format format> | |||
| struct GetSrcData<ctype, format, param::Remap::BorderMode::CONSTANT> { | |||
| static inline ctype get(const ctype* src, int height, int width, | |||
| int channel, int h, int w, int c, float scalar) { | |||
| RoundingConverter<ctype> round; | |||
| return (height >= 0 && height < h && width >= 0 && width < w) | |||
| ? src[get_offset<format>(height, width, channel, h, w, | |||
| c)] | |||
| : static_cast<DataType>(round(scalar)); | |||
| : round(scalar); | |||
| } | |||
| static inline int get_index(int height, int width, int channel, int h, | |||
| int w, int c) { | |||
| return (height >= 0 && height < h && width >= 0 && width < w) | |||
| ? get_offset<format>(height, width, channel, h, w, c) | |||
| : -1; | |||
| } | |||
| }; | |||
| template <typename DataType, param::Remap::Format format, | |||
| template <typename ctype, param::Remap::Format format, | |||
| param::Remap::BorderMode bordertype> | |||
| void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | |||
| int N, int C, int IH, int IW, int OH, int OW, float scalar, | |||
| std::function<DataType(float)> round) { | |||
| void remap_LINEAR(const ctype* src, const float* map_xy, ctype* dst, int N, | |||
| int C, int IH, int IW, int OH, int OW, float scalar) { | |||
| RoundingConverter<ctype> round_converter; | |||
| for (int n = 0; n < N; | |||
| ++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) { | |||
| for (int h = 0; h < OH; ++h) { | |||
| @@ -73,47 +86,131 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | |||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||
| int col = static_cast<int>(floor(index_col)); | |||
| int row = static_cast<int>(floor(index_row)); | |||
| float v = index_col - col; | |||
| float u = index_row - row; | |||
| float one = 1.f; | |||
| float v = index_col - col; // alphaw | |||
| float u = index_row - row; // alphah | |||
| const float one = 1.f; | |||
| for (int c = 0; c < C; ++c) { | |||
| DataType a00 = | |||
| GetSrcData<DataType, format, bordertype>::get( | |||
| src, row + 0, col + 0, c, IH, IW, C, scalar, | |||
| round); | |||
| DataType a01 = | |||
| GetSrcData<DataType, format, bordertype>::get( | |||
| src, row + 0, col + 1, c, IH, IW, C, scalar, | |||
| round); | |||
| DataType a10 = | |||
| GetSrcData<DataType, format, bordertype>::get( | |||
| src, row + 1, col + 0, c, IH, IW, C, scalar, | |||
| round); | |||
| DataType a11 = | |||
| GetSrcData<DataType, format, bordertype>::get( | |||
| src, row + 1, col + 1, c, IH, IW, C, scalar, | |||
| round); | |||
| ctype a00 = GetSrcData<ctype, format, bordertype>::get( | |||
| src, row + 0, col + 0, c, IH, IW, C, scalar); | |||
| ctype a01 = GetSrcData<ctype, format, bordertype>::get( | |||
| src, row + 0, col + 1, c, IH, IW, C, scalar); | |||
| ctype a10 = GetSrcData<ctype, format, bordertype>::get( | |||
| src, row + 1, col + 0, c, IH, IW, C, scalar); | |||
| ctype a11 = GetSrcData<ctype, format, bordertype>::get( | |||
| src, row + 1, col + 1, c, IH, IW, C, scalar); | |||
| dst[get_offset<format>(h, w, c, OH, OW, C)] = | |||
| static_cast<DataType>( | |||
| round(a00 * (one - u) * (one - v) + | |||
| a01 * (one - u) * v + | |||
| a10 * (one - v) * u + a11 * u * v)); | |||
| round_converter(a00 * (one - v) * (one - u) + | |||
| a01 * (one - u) * v + | |||
| a10 * (one - v) * u + a11 * u * v); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename DataType, DTypeCategory cat> | |||
| struct Round { | |||
| static inline DataType round(float x) { return std::round(x); } | |||
| }; | |||
| template <typename ctype, param::Remap::Format format, | |||
| param::Remap::BorderMode bordertype> | |||
| void remap_LINEAR_backwarddata(ctype* grad, const float* map_xy, | |||
| const ctype* diff, int N, int C, int IH, int IW, | |||
| int OH, int OW) { | |||
| RoundingConverter<ctype> round_converter; | |||
| std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW); | |||
| for (int n = 0; n < N; | |||
| ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) { | |||
| for (int h = 0; h < OH; ++h) { | |||
| for (int w = 0; w < OW; ++w) { | |||
| float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||
| int col = static_cast<int>(floor(index_col)); | |||
| int row = static_cast<int>(floor(index_row)); | |||
| float v = index_col - col; // alphaw | |||
| float u = index_row - row; // alphah | |||
| const float one = 1.f; | |||
| for (int c = 0; c < C; ++c) { | |||
| ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)]; | |||
| template <typename DataType> | |||
| struct Round<DataType, DTypeCategory::FLOAT> { | |||
| static inline DataType round(float x) { return static_cast<DataType>(x); } | |||
| }; | |||
| int a00 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 0, col + 0, c, IH, IW, C); | |||
| if (a00 != -1) { | |||
| grad[a00] += | |||
| round_converter((one - v) * (one - u) * hidden); | |||
| } | |||
| int a01 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 0, col + 1, c, IH, IW, C); | |||
| if (a01 != -1) { | |||
| grad[a01] += round_converter((one - u) * v * hidden); | |||
| } | |||
| int a10 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 1, col + 0, c, IH, IW, C); | |||
| if (a10 != -1) { | |||
| grad[a10] += round_converter(u * (one - v) * hidden); | |||
| } | |||
| int a11 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 1, col + 1, c, IH, IW, C); | |||
| if (a11 != -1) { | |||
| grad[a11] += round_converter(v * u * hidden); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| template <typename ctype, param::Remap::Format format, | |||
| param::Remap::BorderMode bordertype> | |||
| void remap_LINEAR_backwardmat(const ctype* src, const float* map_xy, | |||
| const ctype* diff, float* grad, int N, int C, | |||
| int IH, int IW, int OH, int OW, float scalar) { | |||
| RoundingConverter<ctype> round_converter; | |||
| std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); | |||
| for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW, | |||
| map_xy += OH * OW * 2, grad += OH * OW * 2) { | |||
| for (int h = 0; h < OH; ++h) { | |||
| for (int w = 0; w < OW; ++w) { | |||
| float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||
| int col = static_cast<int>(floor(index_col)); | |||
| int row = static_cast<int>(floor(index_row)); | |||
| float v = index_col - col; // alphaw | |||
| float u = index_row - row; // alphah | |||
| const float one = 1.f; | |||
| for (int c = 0; c < C; ++c) { | |||
| float hidden = static_cast<float>( | |||
| diff[get_offset<format>(h, w, c, OH, OW, C)]); | |||
| float du = 0.f, dv = 0.f; | |||
| int a00 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 0, col + 0, c, IH, IW, C); | |||
| int a01 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 0, col + 1, c, IH, IW, C); | |||
| int a10 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 1, col + 0, c, IH, IW, C); | |||
| int a11 = GetSrcData<ctype, format, bordertype>::get_index( | |||
| row + 1, col + 1, c, IH, IW, C); | |||
| dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||
| dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||
| dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||
| dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||
| du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||
| du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||
| du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||
| du += ((a11 != -1) ? src[a11] : scalar) * v; | |||
| grad[h * OW * 2 + w * 2 + 0] += | |||
| round_converter(hidden * dv); | |||
| grad[h * OW * 2 + w * 2 + 1] += | |||
| round_converter(hidden * du); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| @@ -148,8 +245,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| src.compatible_ptr<ctype>(), \ | |||
| map_xy.compatible_ptr<dt_float32>(), \ | |||
| dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | |||
| param().scalar, \ | |||
| Round<ctype, DTypeTrait<dt>::category>::round))); \ | |||
| param().scalar))); \ | |||
| break; \ | |||
| } | |||
| @@ -172,6 +268,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| support_dtype(dtype::Float32); | |||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)); | |||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||
| support_dtype(dtype::Int8); | |||
| support_dtype(dtype::Uint8); | |||
| #undef cb | |||
| @@ -181,3 +278,109 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| megdnn_throw("unsupported dtype in remap naive\n"); | |||
| } | |||
| } | |||
| void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); | |||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||
| "only support NCHW format for remap backward"); | |||
| int N, C, IH, IW, OH, OW; | |||
| N = grad.layout.shape[0]; | |||
| C = grad.layout.shape[1]; | |||
| IH = grad.layout.shape[2]; | |||
| IW = grad.layout.shape[3]; | |||
| OH = map_xy.layout.shape[1]; | |||
| OW = map_xy.layout.shape[2]; | |||
| switch (diff.layout.dtype.enumv()) { | |||
| #define cb(dt, fmt, border, interpolation) \ | |||
| if (param().format == param::Remap::Format::fmt && \ | |||
| param().border_type == param::Remap::BorderMode::border && \ | |||
| param().imode == param::Remap::InterpolationMode::interpolation) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwarddata< \ | |||
| ctype, param::Remap::Format::fmt, \ | |||
| param::Remap::BorderMode::border>( \ | |||
| grad.compatible_ptr<ctype>(), \ | |||
| map_xy.compatible_ptr<dt_float32>(), \ | |||
| diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW))); \ | |||
| break; \ | |||
| } | |||
| #define support_dtype(dt) \ | |||
| case DTypeTrait<dt>::enumv: { \ | |||
| cb(dt, NCHW, CONSTANT, LINEAR); \ | |||
| cb(dt, NCHW, REPLICATE, LINEAR); \ | |||
| cb(dt, NCHW, REFLECT, LINEAR); \ | |||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
| cb(dt, NCHW, WRAP, LINEAR); \ | |||
| megdnn_throw( \ | |||
| "format, border type or imode is incorrect in remap navie " \ | |||
| "with dtype = " #dt); \ | |||
| } | |||
| support_dtype(dtype::Float32); | |||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||
| #undef cb | |||
| #undef support_dtype | |||
| default: | |||
| megdnn_throw("unsupported dtype in remap backward naive\n"); | |||
| } | |||
| } | |||
| void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, | |||
| workspace.size); | |||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||
| "only support NCHW format for remap backward"); | |||
| int N, C, IH, IW, OH, OW; | |||
| N = src.layout.shape[0]; | |||
| C = src.layout.shape[1]; | |||
| IH = src.layout.shape[2]; | |||
| IW = src.layout.shape[3]; | |||
| OH = map_xy.layout.shape[1]; | |||
| OW = map_xy.layout.shape[2]; | |||
| switch (src.layout.dtype.enumv()) { | |||
| #define cb(dt, fmt, border, interpolation) \ | |||
| if (param().format == param::Remap::Format::fmt && \ | |||
| param().border_type == param::Remap::BorderMode::border && \ | |||
| param().imode == param::Remap::InterpolationMode::interpolation) { \ | |||
| using ctype = DTypeTrait<dt>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwardmat< \ | |||
| ctype, param::Remap::Format::fmt, \ | |||
| param::Remap::BorderMode::border>( \ | |||
| src.compatible_ptr<ctype>(), \ | |||
| map_xy.compatible_ptr<dt_float32>(), \ | |||
| diff.compatible_ptr<ctype>(), \ | |||
| grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \ | |||
| param().scalar))); \ | |||
| break; \ | |||
| } | |||
| #define support_dtype(dt) \ | |||
| case DTypeTrait<dt>::enumv: { \ | |||
| cb(dt, NCHW, CONSTANT, LINEAR); \ | |||
| cb(dt, NCHW, REPLICATE, LINEAR); \ | |||
| cb(dt, NCHW, REFLECT, LINEAR); \ | |||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||
| cb(dt, NCHW, WRAP, LINEAR); \ | |||
| megdnn_throw( \ | |||
| "format, border type or imode is incorrect in remap navie " \ | |||
| "with dtype = " #dt); \ | |||
| } | |||
| support_dtype(dtype::Float32); | |||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||
| #undef cb | |||
| #undef support_dtype | |||
| default: | |||
| megdnn_throw("unsupported dtype in remap backward naive\n"); | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -23,6 +23,33 @@ class RemapImpl final : public Remap { | |||
| return 0; | |||
| } | |||
| }; | |||
| class RemapBackwardDataImpl final : public RemapBackwardData { | |||
| public: | |||
| using RemapBackwardData::RemapBackwardData; | |||
| void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class RemapBackwardMatImpl final : public RemapBackwardMat { | |||
| public: | |||
| using RemapBackwardMat::RemapBackwardMat; | |||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&, | |||
| const TensorLayout&) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -106,6 +106,8 @@ DEF(DeformablePSROIPoolingForward, 5, true, true); | |||
| DEF(DeformablePSROIPoolingBackward, 7, true, false); | |||
| DEF(BatchConvBiasForward, 5, true, true); | |||
| DEF(Remap, 3, true, true); | |||
| DEF(RemapBackwardData, 3, true, false); | |||
| DEF(RemapBackwardMat, 4, true, false); | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -46,6 +46,9 @@ static inline std::vector<TestArg> get_nchw_args() { | |||
| for (auto border_type : border_mode_vec) { | |||
| param.format = fmt; | |||
| param.border_type = border_type; | |||
| args.emplace_back(param, TensorShape{70000, 1, 2, 2}, | |||
| TensorShape{70000, 2, 2, 2}, TensorShape{70000, 1, 2, 2}); | |||
| args.emplace_back(param, TensorShape{1, 1, 2, 2}, | |||
| TensorShape{1, 2, 2, 2}, TensorShape{1, 1, 2, 2}); | |||
| @@ -90,6 +93,9 @@ static inline std::vector<TestArg> get_nhwc_args() { | |||
| param.format = fmt; | |||
| param.border_type = border_type; | |||
| param.scalar = 12.f; | |||
| args.emplace_back(param, TensorShape{70000, 2, 2, 1}, | |||
| TensorShape{70000, 2, 2, 2}, TensorShape{70000, 2, 2, 1}); | |||
| args.emplace_back(param, TensorShape{1, 2, 2, 1}, | |||
| TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 1}); | |||
| @@ -40,6 +40,22 @@ TEST_F(CUDA, REMAP_NCHW_FLOAT) { | |||
| cb(dtype::Float32(), float_rng); | |||
| cb(dtype::Float16(), float_rng); | |||
| #undef cb | |||
| #define cb(data_type, data_rng) \ | |||
| for (auto arg : args) { \ | |||
| UniformFloatRNG map_rng( \ | |||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
| checker.set_dtype(0, data_type) \ | |||
| .set_dtype(1, dtype::Float32()) \ | |||
| .set_dtype(2, data_type) \ | |||
| .set_rng(0, &data_rng) \ | |||
| .set_rng(1, &map_rng) \ | |||
| .set_rng(2, &data_rng) \ | |||
| .set_param(arg.param) \ | |||
| .set_epsilon(1e-2) \ | |||
| .execs({arg.src, arg.map_xy, arg.dst}); \ | |||
| } | |||
| cb(dtype::BFloat16(), float_rng); | |||
| #undef cb | |||
| } | |||
| TEST_F(CUDA, REMAP_NCHW_INT) { | |||
| @@ -87,6 +103,22 @@ TEST_F(CUDA, REMAP_NHWC_FLOAT) { | |||
| cb(dtype::Float32(), float_rng); | |||
| cb(dtype::Float16(), float_rng); | |||
| #undef cb | |||
| #define cb(data_type, data_rng) \ | |||
| for (auto arg : args) { \ | |||
| UniformFloatRNG map_rng( \ | |||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
| checker.set_dtype(0, data_type) \ | |||
| .set_dtype(1, dtype::Float32()) \ | |||
| .set_dtype(2, data_type) \ | |||
| .set_rng(0, &data_rng) \ | |||
| .set_rng(1, &map_rng) \ | |||
| .set_rng(2, &data_rng) \ | |||
| .set_param(arg.param) \ | |||
| .set_epsilon(1e-2) \ | |||
| .execs({arg.src, arg.map_xy, arg.dst}); \ | |||
| } | |||
| cb(dtype::BFloat16(), float_rng); | |||
| #undef cb | |||
| } | |||
| TEST_F(CUDA, REMAP_NHWC_INT) { | |||
| @@ -114,6 +146,85 @@ TEST_F(CUDA, REMAP_NHWC_INT) { | |||
| #undef cb | |||
| } | |||
| TEST_F(CUDA, REMAP_BACKWARD_DATA) { | |||
| Checker<RemapBackwardData> checker(handle_cuda()); | |||
| std::vector<TestArg> args = get_nchw_args(); | |||
| UniformFloatRNG float_rng(0, 255); | |||
| #define cb(data_type, data_rng) \ | |||
| for (auto arg : args) { \ | |||
| UniformFloatRNG map_rng( \ | |||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
| checker.set_dtype(1, data_type) \ | |||
| .set_dtype(0, dtype::Float32()) \ | |||
| .set_dtype(2, data_type) \ | |||
| .set_rng(1, &data_rng) \ | |||
| .set_rng(0, &map_rng) \ | |||
| .set_rng(2, &data_rng) \ | |||
| .set_param(arg.param) \ | |||
| .execs({arg.map_xy, arg.dst, arg.src}); \ | |||
| } | |||
| cb(dtype::Float32(), float_rng); | |||
| #undef cb | |||
| #define cb(data_type, data_rng) \ | |||
| for (auto arg : args) { \ | |||
| UniformFloatRNG map_rng( \ | |||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
| checker.set_dtype(1, data_type) \ | |||
| .set_dtype(0, dtype::Float32()) \ | |||
| .set_dtype(2, data_type) \ | |||
| .set_rng(1, &data_rng) \ | |||
| .set_rng(0, &map_rng) \ | |||
| .set_rng(2, &data_rng) \ | |||
| .set_param(arg.param) \ | |||
| .set_epsilon(1e-1) \ | |||
| .execs({arg.map_xy, arg.dst, arg.src}); \ | |||
| } | |||
| cb(dtype::BFloat16(), float_rng); | |||
| #undef cb | |||
| } | |||
| TEST_F(CUDA, REMAP_BACKWARD_MAT) { | |||
| Checker<RemapBackwardMat> checker(handle_cuda()); | |||
| std::vector<TestArg> args = get_nchw_args(); | |||
| UniformFloatRNG float_rng(0, 255); | |||
| #define cb(data_type, data_rng) \ | |||
| for (auto arg : args) { \ | |||
| UniformFloatRNG map_rng( \ | |||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
| checker.set_dtype(0, data_type) \ | |||
| .set_dtype(1, dtype::Float32()) \ | |||
| .set_dtype(2, data_type) \ | |||
| .set_dtype(3, dtype::Float32()) \ | |||
| .set_rng(0, &data_rng) \ | |||
| .set_rng(1, &map_rng) \ | |||
| .set_rng(2, &data_rng) \ | |||
| .set_rng(3, &map_rng) \ | |||
| .set_param(arg.param) \ | |||
| .set_epsilon(2e-2) \ | |||
| .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \ | |||
| } | |||
| cb(dtype::Float32(), float_rng); | |||
| #undef cb | |||
| #define cb(data_type, data_rng) \ | |||
| for (auto arg : args) { \ | |||
| UniformFloatRNG map_rng( \ | |||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||
| checker.set_dtype(0, data_type) \ | |||
| .set_dtype(1, dtype::Float32()) \ | |||
| .set_dtype(2, data_type) \ | |||
| .set_dtype(3, dtype::Float32()) \ | |||
| .set_rng(0, &data_rng) \ | |||
| .set_rng(1, &map_rng) \ | |||
| .set_rng(2, &data_rng) \ | |||
| .set_rng(3, &map_rng) \ | |||
| .set_param(arg.param) \ | |||
| .set_epsilon(1e-1) \ | |||
| .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \ | |||
| } | |||
| cb(dtype::BFloat16(), float_rng); | |||
| #undef cb | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, BENCHMARK_REMAP) { | |||
| @@ -144,13 +255,31 @@ TEST_F(CUDA, BENCHMARK_REMAP) { | |||
| .execs(shapes); | |||
| auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs( | |||
| shapes); | |||
| int size = 0; | |||
| if (dtype == dtype::Float32{}) { | |||
| size = sizeof(float); | |||
| printf("float32: "); | |||
| } else if (dtype == dtype::Float16{}) { | |||
| size = sizeof(dt_float16); | |||
| printf("float16: "); | |||
| } else if (dtype == dtype::Int8{}) { | |||
| size = sizeof(dt_int8); | |||
| printf("int8: "); | |||
| } else if (dtype == dtype::Uint8{}) { | |||
| size = sizeof(dt_uint8); | |||
| printf("uint8: "); | |||
| } | |||
| const TensorShape map_xy = shapes[1]; | |||
| const TensorShape dst_layout = shapes[2]; | |||
| float calc_amount = dst_layout.total_nr_elems(); | |||
| printf("naive={%.3fms, %.3fMflops}, " | |||
| "cuda={%.3fms, %.3fMflops}\n", | |||
| t1 / RUN, calc_amount / (t1 / RUN * 1000), t2, | |||
| calc_amount / (t2 * 1000)); | |||
| float calc_amount = (dst_layout.total_nr_elems() * (4.f + 1.f) * size + | |||
| map_xy.total_nr_elems() * sizeof(float)) / | |||
| (1024 * 1024 * 1024); | |||
| printf("naive={%.3fms, %.3fGBPS}, " | |||
| "cuda={%.3fms, %.3fGBPS}\n", | |||
| t1 / RUN, calc_amount / (t1 / RUN) * 1e3, t2, | |||
| calc_amount / t2 * 1e3); | |||
| }; | |||
| Param param; | |||
| param.imode = param::Remap::InterpolationMode::LINEAR; | |||
| @@ -84,6 +84,7 @@ from .nn import ( | |||
| max_pool2d, | |||
| one_hot, | |||
| prelu, | |||
| remap, | |||
| roi_align, | |||
| roi_pooling, | |||
| softmax, | |||
| @@ -705,6 +705,61 @@ def warp_perspective( | |||
| ) | |||
| @wrap_io_tensor | |||
| def remap( | |||
| inp: Tensor, | |||
| map_xy: Tensor, | |||
| border_mode: str = "REPLICATE", | |||
| scalar: float = 0.0, | |||
| interp_mode: str = "LINEAR", | |||
| ) -> Tensor: | |||
| r""" | |||
| Applies remap transformation to batched 2D images. | |||
| The input images are transformed to the output images by the tensor map_xy. | |||
| The output's H and W are same as map_xy's H and W. | |||
| :param inp: input image | |||
| :param map_xy: (batch, oh, ow, 2) transformation matrix | |||
| :param border_mode: pixel extrapolation method. Default: ``"REPLICATE"`` | |||
| :param scalar: value used in case of a constant border. Default: ``0`` | |||
| :param interp_mode: interpolation methods. Default: ``"LINEAR"`` | |||
| Examples: | |||
| .. testcode:: | |||
| import numpy as np | |||
| from megengine import tensor | |||
| import megengine.functional as F | |||
| inp_shape = (1, 1, 4, 4) | |||
| inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||
| map_xy_shape = (1, 2, 2, 2) | |||
| map_xy = tensor(np.array([[[1., 0.],[0., 1.]], | |||
| [[0., 1.],[0., 1.]]], | |||
| dtype=np.float32).reshape(map_xy_shape)) | |||
| out = F.remap(inp, map_xy) | |||
| print(out.numpy()) | |||
| Outputs: | |||
| .. testoutput:: | |||
| [[[[1. 4.] | |||
| [4. 4.]]]] | |||
| """ | |||
| return mgb.opr.remap( | |||
| inp, | |||
| map_xy, | |||
| border_type=border_mode, | |||
| scalar=scalar, | |||
| imode=interp_mode, | |||
| format="NCHW", | |||
| ) | |||
| @wrap_io_tensor | |||
| def eye( | |||
| n: int, | |||
| @@ -443,4 +443,29 @@ void RemapForward::init_output_dtype() { | |||
| output(0)->dtype(input(0)->dtype()); | |||
| } | |||
| #ifdef MGB_ENABLE_GRAD | |||
| MGB_IMPL_OPR_GRAD(RemapForward) { | |||
| mgb_assert(opr.input().size() == 2); | |||
| if (wrt_idx == 0) { | |||
| SymbolVar grad = | |||
| RemapBackwardData::make(opr.input(1), out_grad[0], | |||
| opr.input(0), opr.param()); | |||
| return grad.node(); | |||
| } else if (wrt_idx == 1) { | |||
| SymbolVar grad = | |||
| RemapBackwardMat::make(opr.input(0), opr.input(1), | |||
| out_grad[0], opr.param()); | |||
| return grad.node(); | |||
| } else | |||
| return InvalidGrad::make(opr, wrt_idx); | |||
| } | |||
| #endif | |||
| /* ====================== RemapBackward ====================== */ | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardData); | |||
| MEGDNN_OPR_INIT3(RemapBackwardData, "remap_bwd_data", 2, false); | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardMat); | |||
| MEGDNN_OPR_INIT3(RemapBackwardMat, "remap_bwd_mat", 1, true); | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -97,6 +97,8 @@ namespace opr { | |||
| MGB_SEREG_OPR(ResizeBackward, 2); | |||
| MGB_SEREG_OPR(Remap, 2); | |||
| MGB_SEREG_OPR(RemapBackwardData, 3); | |||
| MGB_SEREG_OPR(RemapBackwardMat, 3); | |||
| //! current warp affine version | |||
| using WarpAffineV1 = opr::WarpAffine; | |||
| @@ -74,7 +74,7 @@ size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& output_shapes) const override; | |||
| void record_execute_deps(ExecDependencyArray& deps) override; | |||
| }; // namespace opr | |||
| }; | |||
| using WarpPerspective = WarpPerspectiveForward; | |||
| MGB_DEFINE_OPR_CLASS( | |||
| @@ -98,7 +98,7 @@ static SymbolVar make(SymbolVar mat, SymbolVar mat_idx, SymbolVar out_diff, | |||
| const OperatorNodeConfig& config = {}); | |||
| void scn_do_execute() override; | |||
| }; // namespace mgb | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS( | |||
| WarpPerspectiveBackwardMat, | |||
| @@ -119,8 +119,7 @@ static SymbolVar make(SymbolVar src, SymbolVar mat, SymbolVar mat_idx, | |||
| const OperatorNodeConfig& config = {}); | |||
| void scn_do_execute() override; | |||
| } | |||
| ; | |||
| }; | |||
| /* ============================= shape infer ============================== */ | |||
| //! param: src, dst | |||
| @@ -164,8 +163,7 @@ size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const override; | |||
| void record_execute_deps(ExecDependencyArray& deps) override; | |||
| } | |||
| ; | |||
| }; | |||
| using Resize = ResizeForward; | |||
| MGB_DEFINE_OPR_CLASS(ResizeBackward, | |||
| @@ -177,8 +175,7 @@ ResizeBackward(VarNode* out_diff, VarNode* in_for_shape, const Param& param, | |||
| static SymbolVar make(SymbolVar out_diff, SymbolVar in_for_shape, | |||
| const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| } | |||
| ; | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(RemapForward, | |||
| intl::MegDNNOprWrapperFwd<megdnn::RemapForward>) // { | |||
| @@ -192,10 +189,31 @@ static SymbolVar make(SymbolVar in_tensor, SymbolVar map, | |||
| private: | |||
| void init_output_dtype() override; | |||
| } | |||
| ; | |||
| }; | |||
| using Remap = RemapForward; | |||
| MGB_DEFINE_OPR_CLASS(RemapBackwardData, | |||
| intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardData>) // { | |||
| public: | |||
| RemapBackwardData(VarNode *map, VarNode *out_diff, | |||
| VarNode *in_for_shape, const Param ¶m, | |||
| const OperatorNodeConfig &config); | |||
| static SymbolVar make(SymbolVar map, SymbolVar out_diff, | |||
| SymbolVar in_for_shape, const Param ¶m = {}, | |||
| const OperatorNodeConfig &config = {}); | |||
| }; | |||
| MGB_DEFINE_OPR_CLASS(RemapBackwardMat, | |||
| intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardMat>) // { | |||
| public: | |||
| RemapBackwardMat(VarNode *src, VarNode *map, VarNode *out_diff, | |||
| const Param ¶m, const OperatorNodeConfig &config); | |||
| static SymbolVar make(SymbolVar src, SymbolVar map, SymbolVar out_diff, | |||
| const Param ¶m = {}, const OperatorNodeConfig &config = {}); | |||
| }; | |||
| /*! | |||
| * \brief apply affine transformation to batched 2D images | |||
| * | |||
| @@ -238,8 +256,7 @@ size_t get_workspace_size_bytes( | |||
| const TensorShapeArray& input_shapes, | |||
| const TensorShapeArray& output_shapes) const override; | |||
| void record_execute_deps(ExecDependencyArray& deps) override; | |||
| } | |||
| ; | |||
| }; | |||
| using WarpAffine = WarpAffineForward; | |||
| } // opr | |||
| @@ -640,11 +640,11 @@ TEST(TestOprImgproc, WarpAffineForward) { | |||
| } | |||
| TEST(TestOprImgproc, Remap_NCHW) { | |||
| constexpr size_t N = 2, C = 8; | |||
| constexpr size_t N = 2, C = 8, OH = 10, OW = 10; | |||
| opr::Remap::Param param; | |||
| using Checker = AutoOprChecker<2, 1>; | |||
| TensorShape out_shp{N, C, 10, 10}; | |||
| TensorShape out_shp{N, C, OH, OW}; | |||
| param.format = opr::Remap::Param::Format::NCHW; | |||
| auto make_graph = [&](const Checker::SymInpArray &inputs) -> | |||
| Checker::SymOutArray { | |||
| @@ -657,12 +657,34 @@ TEST(TestOprImgproc, Remap_NCHW) { | |||
| opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), {}); | |||
| }; | |||
| std::mt19937 rng(next_rand_seed()); | |||
| auto rand_real = [&](double lo, double hi) { | |||
| auto real = rng() / (std::mt19937::max() + 1.0) * (hi - lo) + lo; | |||
| if(std::abs(std::round(real) - real) <= 1e-2) | |||
| return real + 1e-1; | |||
| return real; | |||
| }; | |||
| auto rand_real2 = [&](double range) { | |||
| return rand_real(-range, range); | |||
| }; | |||
| auto gen_mat = [&](HostTensorND& mat) { | |||
| auto ptr = mat.ptr<float>(); | |||
| for (size_t i = 0; i < N; ++ i) { | |||
| for(size_t j = 0; j < OH * OW * 2; j++) { | |||
| //! undifferentiable when map is an integer | |||
| ptr[j] = static_cast<float>(rand_real2(20)); | |||
| } | |||
| ptr += OH * OW * 2; | |||
| } | |||
| mgb_assert(ptr == mat.ptr<float>() + mat.shape().total_nr_elems()); | |||
| }; | |||
| Checker::RunOptions opt; | |||
| Checker(make_graph, fwd, CompNode::load("cpu1")) | |||
| .disable_grad_check() | |||
| .run({TensorShape{N, C, 3, 20}, TensorShape{N, 10, 10, 2}}, opt) | |||
| .run({TensorShape{N, C, 6, 5}, TensorShape{N, 10, 10, 2}}, opt) | |||
| .run({TensorShape{N, C, 20, 20}, TensorShape{N, 10, 10, 2}}, opt); | |||
| .set_input_generator(1, gen_mat) | |||
| .run({TensorShape{N, C, 3, 20}, TensorShape{N, OH, OW, 2}}, opt) | |||
| .run({TensorShape{N, C, 6, 5}, TensorShape{N, OH, OW, 2}}, opt) | |||
| .run({TensorShape{N, C, 20, 20}, TensorShape{N, OH, OW, 2}}, opt); | |||
| } | |||
| TEST(TestOprImgproc, Remap_NHWC) { | |||
| @@ -690,4 +712,5 @@ TEST(TestOprImgproc, Remap_NHWC) { | |||
| .run({TensorShape{N, 6, 5, C}, TensorShape{N, 10, 10, 2}}, opt) | |||
| .run({TensorShape{N, 20, 20, C}, TensorShape{N, 10, 10, 2}}, opt); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||